diff --git a/mlir/docs/Traits/Broadcastable.md b/mlir/docs/Traits/Broadcastable.md
new file mode 100644
--- /dev/null
+++ b/mlir/docs/Traits/Broadcastable.md
@@ -0,0 +1,199 @@
+# The `Broadcastable` Trait
+
+[TOC]
+
+## Description
+
+The `Broadcastable` trait enforces the following properties on an operation:
+
+- The operation has at least one input operand.
+
+- The operation has exactly one result.
+
+- All input operands and result are of type `tensor` or `vector`.
+
+- A shape inference mechanism is able to compute the result shape solely based on input operand shapes.
+
+- Input operands have broadcast-compatible shapes, according to the verification rules presented below.
+
+- The operation's result shape is compatible with, though not necessarily identical to, the shape inferred from its input operands, according to the verification rules presented below.
+
+
+## Dimension inference
+
+Given an operation with two input operands, the size of dimension *i* of its result can be inferred from dimension *i* of the operands according to the table below. Here, `dim0` and `dim1` represent dimension *i* of the input operands in an interchangeable order, while `inferredDim` represents the inferred size for dimension *i* of the operation result. Dimensions are classified into three categories: dynamic ("?"), static equal to 1 ("1"), and static greater than 1 (">1").
+
+
+| `dim0` | `dim1` | `inferredDim` | Notes |
+| -------- | -------- | ------------- | ----- |
+| ? | ? | ? | The operation produces undefined behavior if `RuntimeSize(dim0)` != `RuntimeSize(dim1)`. |
+| ? | 1 | ? | Dimension `dim1` is broadcast to `RuntimeSize(dim0)`. |
+| ? | >1 | `dim1` | The operation produces undefined behavior if `RuntimeSize(dim0)` != `dim1`. |
+| 1 | 1 | 1 | |
+| 1 | >1 | `dim1` | Dimension `dim0` is broadcast to `dim1`. |
+| >1 | >1 | `dim0` | The operation verifier produces a compile-time error if `dim0` != `dim1`. |
+
+
+As reflected in this table, a dimension may only be broadcast when it is static and explicitly set to size 1. Broadcast semantics do not apply to dynamic dimensions with a runtime size of 1. This restriction is aimed at easing the lowering process of broadcastable operations. A lowering pass is exempt from emitting logic that selectively broadcasts dynamic dimensions according to their runtime size.
+
+The following pseudo-function is a formal representation of the dimension inference process:
+
+```python
+InferDim(dim0, dim1):
+    switch (dim0, dim1):
+        case (?, ?):
+        case (?, 1):
+        case (1, 1):
+        case (>1, ?):
+        case (>1, 1):
+            return dim0
+        case (?, >1):
+        case (1, ?):
+        case (1, >1):
+            return dim1
+        case (>1, >1):
+            ERROR_IF(dim0 != dim1)
+            return dim0
+```
+
+## Shape inference
+
+The shape inference process begins by correcting rank differences in input operands. A shape is expanded by adding additional dimensions of size 1 on its left until the desired rank is reached, as shown here:
+
+```python
+ExpandRank(shape, rank):
+    while len(shape) < rank:
+        shape.prepend(1)
+```
+
+Given the shapes of two ranked input operands, the result's shape is inferred by equalizing input ranks and inferring individual dimensions, as shown here:
+
+```python
+InferShape(shape0, shape1):
+
+    # Equalize ranks
+    rank = max(GetRank(shape0), GetRank(shape1))
+    ExpandRank(shape0, rank)
+    ExpandRank(shape1, rank)
+
+    # Infer shape
+    inferredShape = []
+    for (dim0, dim1) in zip(shape0, shape1):
+        inferredDim = InferDim(dim0, dim1)
+        inferredShape.append(inferredDim)
+    return inferredShape
+```
+
+The result shape for an operation with an arbitrary number of input operands is then inferred by discarding unranked operands, applying shape inference on the first ranked operand pair, and updating the inferred shape with each additional ranked operand. If the operation has no ranked operands, the result shape cannot be inferred. If the operation has exactly one ranked operand, its shape is directly provided as the inferred result shape. Formally:
+
+```python
+InferResultShape(op):
+
+    # Filter ranked operands
+    rankedOperands = filter(op.operands, IsRanked)
+    if len(rankedOperands) == 0:
+        return None
+
+    # Infer result shape
+    inferredShape = GetShape(rankedOperands[0])
+    for operand in rankedOperands[1:]:
+        inferredShape = InferShape(inferredShape, GetShape(operand))
+    return inferredShape
+```
+
+## Verification
+
+The legality of an operation with the `Broadcastable` trait is verified by first running the shape inference process. If a failure occurs during shape inference, it is concluded that input operands are not broadcast-compatible, and verification fails. If shape inference succeeds, verification continues.
+
+If either the result is unranked or all input operands are unranked, no further verification steps are needed, and the process ends here successfully. If, on the contrary, both the result and at least one input operand are ranked, verification continues by checking for a matching rank between the previously inferred shape and the result.
+
+Once a rank match is guaranteed, each dimension of the inferred shape is compared with the corresponding dimension of the actual result shape according to the following table:
+
+
+| `inferredDim` | `actualDim` | Verification outcome |
+| ------------- | ----------- | -------------------- |
+| ? | ? | **OK** |
+| ? | static | **Error** <br> An inferred dimension being dynamic indicates that its size cannot be inferred at compile time from its input operands. The presence of a static dimension in the actual result is counterintuitive and is therefore not allowed. |
+| static | ? | **OK** <br> The actual result dimension may be dynamic even when a static size can be inferred at compile time. The programmer may choose to relax the specificity of the result dimension for forward compatibility of the result type. |
+| static | static | **OK if equal** <br> When both the inferred and actual dimensions are static, they must be set to the same size. |
+
+
+The full verification process can be formally specified as follows:
+
+```python
+Verify(op):
+
+    # Run shape inference
+    inferredShape = InferResultShape(op)
+
+    # Done if result is unranked or all operands are unranked
+    if not IsRanked(op.result) or inferredShape is None:
+        return
+
+    # Rank must match
+    actualShape = GetShape(op.result)
+    ERROR_IF(len(inferredShape) != len(actualShape))
+
+    # Verify
+    for (inferredDim, actualDim) in zip(inferredShape, actualShape):
+        ERROR_IF(IsDynamic(inferredDim) and IsStatic(actualDim))
+        ERROR_IF(IsStatic(actualDim) and inferredDim != actualDim)
+```
+
+## Examples
+
+The following are correct uses of broadcastable ops:
+
+```mlir
+// Exact match of static sizes.
+%result = "test.broadcastable"(%arg0, %arg1) : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
+
+// Dynamic sizes match. The programmer must guarantee that the sizes of
+// %arg0 and %arg1 are equal at runtime.
+%result = "test.broadcastable"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
+
+// The shape of %arg0 is broadcast from tensor<1xi32> to tensor<4xi32>.
+%result = "test.broadcastable"(%arg0, %arg1) : (tensor<1xi32>, tensor<4xi32>) -> tensor<4xi32>
+
+// The shape of %result is inferred as tensor<4xi32>, while the actual result
+// type is tensor<?xi32>. The inferred shape is compatible with the actual shape.
+%result = "test.broadcastable"(%arg0) : (tensor<4xi32>) -> tensor<?xi32>
+
+// The shape of %arg0 is first expanded to tensor<1x1x4xi32> and then broadcast
+// to tensor<2x3x4xi32>.
+%result = "test.broadcastable"(%arg0, %arg1) : (tensor<4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32>
+
+// Input and result tensors have different element types (i1, i32, i64). The
+// 'Broadcastable' trait has no restrictions on element types.
+%result = "test.broadcastable"(%arg0, %arg1) : (tensor<2xi1>, tensor<2xi32>) -> tensor<2xi64>
+
+// No result shape verification is needed when the result is unranked.
+%result = "test.broadcastable"(%arg0) : (tensor<2xi32>) -> tensor<*xi32>
+
+// No result shape verification needed when all inputs are unranked.
+%result = "test.broadcastable"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<2xi32>
+```
+
+
+The following are incorrect uses of broadcastable ops:
+
+```mlir
+// Dimension 0 of input operands is static but not equal.
+%result = "test.broadcastable"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<?xi32>
+
+// The inferred result shape is tensor<3xi32>, but the actual result shape is
+// tensor<1x3xi32>. Inferred and actual shapes differ in rank.
+%result = "test.broadcastable"(%arg0, %arg1) : (tensor<3xi32>, tensor<3xi32>) -> tensor<1x3xi32>
+
+// The inferred result shape is tensor<?xi32>, but the actual shape is
+// tensor<4xi32>. The inferred shape is not compatible with the actual shape.
+%result = "test.broadcastable"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<4xi32>
+
+// The inferred result shape is tensor<2xi32>, but the actual result shape is
+// tensor<4xi32>, which is not compatible.
+%result = "test.broadcastable"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<4xi32>
+
+// The inferred result shape is tensor<1xi32>, but the actual result shape is
+// tensor<4xi32>. Broadcast semantics are not applicable for results.
+%result = "test.broadcastable"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
+```
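The inference pseudo-code added in `Broadcastable.md` above can be tried out with a small runnable Python sketch. It is only an illustration, not part of the patch: the `DYN` marker, the use of `None` for an unranked operand, and the function names are assumptions made here, and static sizes are assumed to be at least 1.

```python
DYN = "?"  # hypothetical marker for a dynamic dimension


def infer_dim(dim0, dim1):
    """Infer one result dimension (static sizes assumed >= 1)."""
    if dim0 == DYN:
        # (?, ?) and (?, 1) stay dynamic; (?, >1) adopts the static size.
        return dim1 if dim1 != DYN and dim1 > 1 else DYN
    if dim1 == DYN:
        # (1, ?) stays dynamic; (>1, ?) adopts the static size.
        return dim0 if dim0 > 1 else DYN
    if dim0 == 1:
        return dim1
    if dim1 == 1:
        return dim0
    if dim0 != dim1:
        raise ValueError(f"incompatible static dimensions {dim0} and {dim1}")
    return dim0


def expand_rank(shape, rank):
    """Left-pad a shape with size-1 dimensions up to `rank`."""
    return [1] * (rank - len(shape)) + list(shape)


def infer_shape(shape0, shape1):
    """Equalize ranks, then infer each dimension pairwise."""
    rank = max(len(shape0), len(shape1))
    pairs = zip(expand_rank(shape0, rank), expand_rank(shape1, rank))
    return [infer_dim(dim0, dim1) for dim0, dim1 in pairs]


def infer_result_shape(operand_shapes):
    """Fold infer_shape over ranked operands; None marks an unranked operand."""
    ranked = [shape for shape in operand_shapes if shape is not None]
    if not ranked:
        return None
    inferred = list(ranked[0])
    for shape in ranked[1:]:
        inferred = infer_shape(inferred, shape)
    return inferred


print(infer_result_shape([[4], [2, 3, 4]]))  # [2, 3, 4]
print(infer_result_shape([[1], [4]]))        # [4]
print(infer_result_shape([[DYN], [1]]))      # ['?']
```

Unlike the in-place `ExpandRank` in the document, `expand_rank` here returns a new list, which keeps the caller's shapes untouched.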
diff --git a/mlir/docs/Traits.md b/mlir/docs/Traits/_index.md
rename from mlir/docs/Traits.md
rename to mlir/docs/Traits/_index.md
--- a/mlir/docs/Traits.md
+++ b/mlir/docs/Traits/_index.md
@@ -225,16 +225,7 @@
This trait adds the property that the operation is known to have
[broadcast-compatible](https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-operands and its result types' shape is the broadcast compatible with the shape
-of the broadcasted operands. Specifically, starting from the most varying
-dimension, each dimension pair of the two operands' shapes should either be the
-same or one of them is one. Also, the result shape should have the corresponding
-dimension equal to the larger one, if known. Shapes are checked partially if
-ranks or dimensions are not known. For example, an op with `tensor<?x2xf32>` and
-`tensor<2xf32>` as operand types and `tensor<3x2xf32>` as the result type is
-broadcast-compatible.
-
-This trait requires that the operands are either vector or tensor types.
+operands and that its result type is compatible with the inferred broadcast shape. See [The `Broadcastable` Trait](Traits/Broadcastable.md) for details.
### Commutative
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -10,6 +10,7 @@
//
//===----------------------------------------------------------------------===//
+#include "llvm/ADT/Sequence.h"
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
@@ -25,6 +26,7 @@
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -518,112 +520,147 @@
return nullptr;
}
-static LogicalResult
-elementwiseMatchAndRewriteHelper(Operation *operation,
- PatternRewriter &rewriter) {
- auto loc = operation->getLoc();
-
- assert(operation->getNumResults() == 1 &&
- "All TOSA elementwise ops should only return a single result.");
-
- auto result = operation->getResult(0);
- auto resultTy = dyn_cast(result.getType());
-
- if (!resultTy)
- return rewriter.notifyMatchFailure(
- operation, "All results must be a ranked tensor type");
-
- unsigned rank = resultTy.getRank();
-
- // Construct the indexing maps needed for linalg.generic ops.
- SmallVector bodyArgTypes;
+static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
+ int64_t rank) {
+ // No need to expand if we are already at the desired rank
+ auto shapedType = dyn_cast<ShapedType>(tensor.getType());
+ assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
+ int64_t numExtraDims = rank - shapedType.getRank();
+ assert(numExtraDims >= 0 && "cannot expand tensor to a lower rank");
+ if (!numExtraDims)
+ return tensor;
+
+ // Compute reassociation indices
+ SmallVector<SmallVector<int64_t, 2>> reassociationIndices(shapedType.getRank());
+ int64_t index = 0;
+ for (index = 0; index <= numExtraDims; index++)
+ reassociationIndices[0].push_back(index);
+ for (size_t position = 1; position < reassociationIndices.size(); position++)
+ reassociationIndices[position].push_back(index++);
+
+ // Compute result type
+ SmallVector<int64_t> resultShape;
+ for (index = 0; index < numExtraDims; index++)
+ resultShape.push_back(1);
+ for (auto size : shapedType.getShape())
+ resultShape.push_back(size);
+ auto resultType = RankedTensorType::get(resultShape, shapedType.getElementType());
+
+ // Emit 'tensor.expand_shape' op
+ return rewriter.create<tensor::ExpandShapeOp>(
+ loc, resultType, tensor, reassociationIndices);
+}
- for (Value in : operation->getOperands())
- bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType()));
+static Value
+getDominantValueForDim(Value lhs, Value rhs, unsigned index) {
+ auto lhsDimSize = lhs.getType().cast<ShapedType>().getDimSize(index);
+ auto rhsDimSize = rhs.getType().cast<ShapedType>().getDimSize(index);
+ if ((ShapedType::isDynamic(lhsDimSize) && rhsDimSize > 1) ||
+ (lhsDimSize == 1 && ShapedType::isDynamic(rhsDimSize)) ||
+ (lhsDimSize == 1 && rhsDimSize > 1))
+ return rhs;
+ return lhs;
+}
- SmallVector opResultTypes;
- SmallVector emptyTensors;
+static Value
+getDominantValueForDim(ValueRange values, unsigned index) {
+ auto dominantValue = values.front();
+ for (auto value : values.drop_front())
+ dominantValue = getDominantValueForDim(dominantValue, value, index);
+ return dominantValue;
+}
- SmallVector dynDims;
- dynDims.resize(rank);
+static OpFoldResult
+getTensorDim(PatternRewriter &rewriter, Location loc, Value tensor, int64_t index) {
+ auto shapedType = dyn_cast<ShapedType>(tensor.getType());
+ assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
+ assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
+ if (shapedType.isDynamicDim(index))
+ return rewriter.create<tensor::DimOp>(loc, tensor, index).getResult();
+ return rewriter.getIndexAttr(shapedType.getDimSize(index));
+}
- for (auto arg : operation->getOperands()) {
- auto operandTy = cast(arg.getType());
- for (int i = 0; i < operandTy.getRank(); i++) {
- if (operandTy.isDynamicDim(i) && !dynDims[i])
- dynDims[i] = rewriter.create(loc, arg, i);
- }
+static Value createOutputTensor(PatternRewriter &rewriter, Location loc,
+ ValueRange values, Type elementType) {
+ SmallVector<OpFoldResult> shape;
+ auto rank = values.front().getType().cast<ShapedType>().getRank();
+ for (auto index : llvm::seq<int64_t>(0, rank)) {
+ auto dominantValue = getDominantValueForDim(values, index);
+ auto dim = getTensorDim(rewriter, loc, dominantValue, index);
+ shape.push_back(dim);
}
+ return rewriter.create<tensor::EmptyOp>(loc, shape, elementType);
+}
- SmallVector filteredDims = condenseValues(dynDims);
-
- emptyTensors.push_back(
- rewriter.create(loc, resultTy, filteredDims));
- opResultTypes.push_back(result.getType());
-
- auto bodyResultTypes = llvm::to_vector<4>(llvm::map_range(
- emptyTensors, [](Value v) { return getElementTypeOrSelf(v); }));
-
- SmallVector operands;
- SmallVector indexingMaps;
- indexingMaps.reserve(operation->getNumOperands() + bodyResultTypes.size());
-
- // Input indexing maps may be broadcasted.
- for (Value operand : operation->getOperands()) {
- ShapedType type = cast(operand.getType());
-
- if (type.getShape() == resultTy.getShape()) {
- operands.push_back(operand);
- indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
- continue;
- }
+static bool
+operandsAndResultsRanked(Operation* operation) {
+ auto isRanked = [](Value value) { return isa<RankedTensorType>(value.getType()); };
+ return llvm::all_of(operation->getOperands(), isRanked) &&
+ llvm::all_of(operation->getResults(), isRanked);
+}
- SmallVector newShape;
- SmallVector affineExprs;
- newShape.reserve(type.getRank());
- for (const auto &it : llvm::enumerate(type.getShape())) {
- if (it.value() == resultTy.getDimSize(it.index())) {
- newShape.push_back(it.value());
- affineExprs.push_back(
- mlir::getAffineDimExpr(it.index(), rewriter.getContext()));
- }
- }
+static LogicalResult
+elementwiseMatchAndRewriteHelper(Operation *operation,
+ PatternRewriter &rewriter) {
- if (newShape.size() != rank) {
- operand = rewriter.create(
- loc, RankedTensorType::get(newShape, type.getElementType()), operand,
- rewriter.getDenseI64ArrayAttr(newShape));
+ // Collect op properties
+ assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
+ assert(operation->getNumOperands() >= 1 && "elementwise op expects at least 1 operand");
+ auto loc = operation->getLoc();
+ auto result = operation->getResult(0);
+ auto resultType = result.getType().cast<ShapedType>();
+
+ // Check supported features for this pass
+ if (!operandsAndResultsRanked(operation))
+ return rewriter.notifyMatchFailure(operation, "Unranked tensors not supported");
+
+ // Equalize input ranks
+ auto rank = resultType.getRank();
+ auto expandedOperands = llvm::map_to_vector(operation->getOperands(), [&](Value operand) {
+ return expandRank(rewriter, loc, operand, rank);
+ });
+
+ // Create output tensor
+ auto outputTensor = createOutputTensor(
+ rewriter, loc, expandedOperands, resultType.getElementType());
+
+ // Build affine maps
+ auto affineMaps = llvm::map_to_vector(expandedOperands, [&](auto operand) {
+ auto shape = cast<ShapedType>(operand.getType()).getShape();
+ SmallVector<AffineExpr> affineExprs;
+ for (auto it : llvm::enumerate(shape)) {
+ auto affineExpr = it.value() == 1 ?
+ rewriter.getAffineConstantExpr(0) :
+ rewriter.getAffineDimExpr(it.index());
+ affineExprs.push_back(affineExpr);
}
+ return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
+ });
+ affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
- operands.push_back(operand);
- indexingMaps.push_back(AffineMap::get(
- /*dimCount=*/rank, /*symbolCount=*/0, affineExprs,
- rewriter.getContext()));
- }
-
- indexingMaps.append(operation->getNumResults(),
- rewriter.getMultiDimIdentityMap(rank));
-
- bool didEncounterError = false;
+ // Emit 'linalg.generic' op
+ bool encounteredError = false;
auto linalgOp = rewriter.create<linalg::GenericOp>(
- loc, opResultTypes, operands, emptyTensors, indexingMaps,
+ loc, outputTensor.getType(), expandedOperands, outputTensor, affineMaps,
getNParallelLoopsAttrs(rank),
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
Value opResult = createLinalgBodyCalculationForElementwiseOp(
operation, blockArgs.take_front(operation->getNumOperands()),
- bodyResultTypes, rewriter);
+ {resultType.getElementType()}, rewriter);
if (!opResult) {
- didEncounterError = true;
+ encounteredError = true;
return;
}
nestedBuilder.create<linalg::YieldOp>(loc, opResult);
});
-
- if (didEncounterError)
+ if (encounteredError)
return rewriter.notifyMatchFailure(
operation, "unable to create linalg.generic body for elementwise op");
-
- rewriter.replaceOp(operation, linalgOp->getResults());
+
+ // Cast 'linalg.generic' result into original result type if needed
+ auto castResult = rewriter.createOrFold<tensor::CastOp>(
+ loc, resultType, linalgOp->getResult(0));
+ rewriter.replaceOp(operation, castResult);
return success();
}
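The index arithmetic in the new lowering helpers can be checked in isolation with a short Python sketch. It mirrors, under assumptions, three pieces of the C++ above: the reassociation groups built by `expandRank`, the per-operand affine expressions (constant `0` for a static size-1 dimension, `d<i>` otherwise), and the operand selection done by `getDominantValueForDim`. The `DYN` marker, the Python function names, and the restriction to source rank of at least 1 are illustrative choices, not part of the patch.

```python
DYN = "?"  # hypothetical marker for a dynamic dimension


def reassociation_indices(source_rank, target_rank):
    """Group target dims per source dim; extra leading dims join group 0."""
    num_extra = target_rank - source_rank
    assert num_extra >= 0, "cannot expand tensor to a lower rank"
    assert source_rank >= 1, "sketch assumption: at least one source dim"
    groups = [list(range(num_extra + 1))]
    groups += [[num_extra + position] for position in range(1, source_rank)]
    return groups


def indexing_exprs(shape):
    """Static size-1 dims read index 0; every other dim follows the loop."""
    return [0 if size == 1 else f"d{position}"
            for position, size in enumerate(shape)]


def dominant_operand(shapes, index):
    """Position of the operand whose dim `index` sizes the output tensor."""
    dominant = 0
    for candidate in range(1, len(shapes)):
        lhs, rhs = shapes[dominant][index], shapes[candidate][index]
        rhs_static_gt1 = rhs != DYN and rhs > 1
        if (lhs == DYN and rhs_static_gt1) or \
           (lhs == 1 and (rhs == DYN or rhs_static_gt1)):
            dominant = candidate
    return dominant


print(reassociation_indices(2, 3))                # [[0, 1], [2]]
print(indexing_exprs([1, 5]))                     # [0, 'd1']
print(dominant_operand([[1, DYN], [DYN, 1]], 0))  # 1
```

The last two calls correspond to the `(d0, d1) -> (0, d1)` map and the `tensor.dim [[ARG1]], [[ZERO]]` check in the broadcast tests further down in this patch.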
diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
--- a/mlir/lib/Dialect/Traits.cpp
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -195,18 +195,22 @@
static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
ArrayRef<int64_t> existing) {
- auto isCompatible = [](int64_t dim1, int64_t dim2) {
- // If the inferred and existing dim is the same, or one of them is unknown
- // then it is compatible, else if the inferred dim is 1 then it is also
- // compatible. But if the existing dim is 1 and the inferred is greater than
- // 1 then flag.
- return dim1 == dim2 || ShapedType::isDynamic(dim1) ||
- ShapedType::isDynamic(dim2) || dim1 == 1;
+ auto isCompatible = [](int64_t inferredDim, int64_t existingDim) {
+ // The following criterion is used to determine the validity of an existing
+ // dimension:
+ //
+ // inferredDim existingDim Behavior
+ // ----------- ----------- --------
+ // dynamic dynamic OK
+ // dynamic static Error
+ // static dynamic OK
+ // static static OK if equal
+ return ShapedType::isDynamic(existingDim) || inferredDim == existingDim;
};
if (inferred.size() != existing.size())
return false;
- for (auto p : llvm::zip(inferred, existing))
- if (!isCompatible(std::get<0>(p), std::get<1>(p)))
+ for (auto [inferredDim, existingDim] : llvm::zip(inferred, existing))
+ if (!isCompatible(inferredDim, existingDim))
return false;
return true;
}
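The rewritten `isCompatible` lambda collapses the four-row table in its comment into a single expression. A minimal Python sketch of the same check is shown here to make the truth table easy to exercise. `K_DYNAMIC` is a hypothetical stand-in for `ShapedType::kDynamic`, and the function names are illustrative.

```python
K_DYNAMIC = None  # hypothetical stand-in for ShapedType::kDynamic


def is_compatible(inferred_dim, existing_dim):
    """A dynamic existing dim is always accepted; otherwise sizes must match."""
    return existing_dim is K_DYNAMIC or inferred_dim == existing_dim


def is_compatible_inferred_return_shape(inferred, existing):
    """Ranks must match, then every dimension pair must be compatible."""
    if len(inferred) != len(existing):
        return False
    return all(is_compatible(i, e) for i, e in zip(inferred, existing))


print(is_compatible_inferred_return_shape([4], [K_DYNAMIC]))  # True
print(is_compatible_inferred_return_shape([K_DYNAMIC], [4]))  # False
print(is_compatible_inferred_return_shape([1], [4]))          # False
```

The third call shows the behavioral change relative to the removed lambda, which accepted an inferred size of 1 against any existing size.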
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1,89 +1,148 @@
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg))" %s -verify-diagnostics -o -| FileCheck %s
+
// CHECK: #[[$MAP0:.*]] = affine_map<() -> ()>
-// CHECK-LABEL: @test_abs
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_abs(%arg0: tensor<f32>) -> tensor<f32> {
+// CHECK-LABEL: @test_abs_scalar
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
+func.func @test_abs_scalar(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK: [[INIT:%.+]] = tensor.empty() : tensor<f32>
- // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = []} ins(%[[ARG0]] : tensor<f32>) outs([[INIT]] : tensor<f32>) {
- // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32):
- // CHECK: [[ELEMENT:%.+]] = math.absf %[[ARG1]]
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = []} ins([[ARG0]] : tensor<f32>) outs([[INIT]] : tensor<f32>) {
+ // CHECK: ^bb0([[ARG1:%.*]]: f32, [[ARG2:%.*]]: f32):
+ // CHECK: [[ELEMENT:%.*]] = math.absf [[ARG1]] : f32
// CHECK: linalg.yield [[ELEMENT]] : f32
// CHECK: } -> tensor<f32>
+ %0 = "tosa.abs"(%arg0) : (tensor<f32>) -> tensor<f32>
- %0 = "tosa.abs"(%arg0) : (tensor<f32>) -> tensor<f32>
+ // CHECK: return [[GENERIC]] : tensor<f32>
+ return %0 : tensor<f32>
+}
- // CHECK: return [[GENERIC]]
- return %0 : tensor<f32>
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-LABEL: @test_abs_static_dynamic
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
+func.func @test_abs_static_dynamic(%arg0: tensor<3x?xf32>) -> tensor<3x?xf32> {
+ // CHECK: [[ONE:%.+]] = arith.constant 1 : index
+ // CHECK: [[DIM:%.+]] = tensor.dim [[ARG0]], [[ONE]] : tensor<3x?xf32>
+ // CHECK: [[INIT:%.+]] = tensor.empty([[DIM]]) : tensor<3x?xf32>
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins([[ARG0]] : tensor<3x?xf32>) outs([[INIT]] : tensor<3x?xf32>) {
+ // CHECK: ^bb0([[ARG1:%.*]]: f32, [[ARG2:%.*]]: f32):
+ // CHECK: [[ELEMENT:%.*]] = math.absf [[ARG1]] : f32
+ // CHECK: linalg.yield [[ELEMENT]] : f32
+ // CHECK: } -> tensor<3x?xf32>
+ %0 = "tosa.abs"(%arg0) : (tensor<3x?xf32>) -> tensor<3x?xf32>
+
+ // CHECK: return [[GENERIC]] : tensor<3x?xf32>
+ return %0 : tensor<3x?xf32>
}
// -----
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
-// CHECK-LABEL: @test_abs
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_abs(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+// CHECK-LABEL: @test_abs_cast_result
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
+func.func @test_abs_cast_result(%arg0: tensor<2xf32>) -> tensor<?xf32> {
// CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xf32>
- // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xf32>) outs([[INIT]] : tensor<2xf32>) {
- // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32):
- // CHECK: [[ELEMENT:%.+]] = math.absf %[[ARG1]]
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]] : tensor<2xf32>) outs([[INIT]] : tensor<2xf32>) {
+ // CHECK: ^bb0([[ARG1:%.*]]: f32, [[ARG2:%.*]]: f32):
+ // CHECK: [[ELEMENT:%.*]] = math.absf [[ARG1]] : f32
// CHECK: linalg.yield [[ELEMENT]] : f32
// CHECK: } -> tensor<2xf32>
- %0 = "tosa.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ %0 = "tosa.abs"(%arg0) : (tensor<2xf32>) -> tensor<?xf32>
- // CHECK: return [[GENERIC]]
- return %0 : tensor<2xf32>
+ // CHECK: [[CAST:%.+]] = tensor.cast [[GENERIC]] : tensor<2xf32> to tensor<?xf32>
+ // CHECK: return [[CAST]] : tensor<?xf32>
+ return %0 : tensor<?xf32>
}
// -----
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-LABEL: @test_abs
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_abs(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
- // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2x3xf32>
- // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xf32>) outs([[INIT]] : tensor<2x3xf32>) {
- // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32):
- // CHECK: [[ELEMENT:%.+]] = math.absf %[[ARG1]]
+// CHECK-LABEL: @test_add_all_static
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]:
+// CHECK-SAME: [[ARG1:%[0-9a-zA-Z_]*]]:
+func.func @test_add_all_static(%arg0: tensor<3x5xf32>, %arg1: tensor<3x5xf32>) -> tensor<3x5xf32> {
+ // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<3x5xf32>
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins([[ARG0]], [[ARG1]] : tensor<3x5xf32>, tensor<3x5xf32>) outs([[INIT]] : tensor<3x5xf32>) {
+ // CHECK: ^bb0([[ARG2:%.*]]: f32, [[ARG3:%.*]]: f32, [[ARG4:%.*]]: f32):
+ // CHECK: [[ELEMENT:%.*]] = arith.addf [[ARG2]], [[ARG3]] : f32
// CHECK: linalg.yield [[ELEMENT]] : f32
- // CHECK: } -> tensor<2x3xf32>
- %0 = "tosa.abs"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
-
- // CHECK: return [[GENERIC]]
- return %0 : tensor<2x3xf32>
+ // CHECK: } -> tensor<3x5xf32>
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<3x5xf32>, tensor<3x5xf32>) -> tensor<3x5xf32>
+
+ // CHECK: return [[GENERIC]] : tensor<3x5xf32>
+ return %0 : tensor<3x5xf32>
}
// -----
-// CHECK-LABEL: @test_abs
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_abs(%arg0: tensor<?xf32>) -> tensor<?xf32> {
- // CHECK: %[[C0:.+]] = arith.constant 0
- // CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]]
- // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]])
- // CHECK: linalg.generic
- // CHECK: math.absf
- %0 = "tosa.abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
- return %0 : tensor<?xf32>
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-LABEL: @test_add_all_dynamic
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]:
+// CHECK-SAME: [[ARG1:%[0-9a-zA-Z_]*]]:
+func.func @test_add_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // CHECK-DAG: [[ZERO:%.+]] = arith.constant 0 : index
+ // CHECK-DAG: [[ONE:%.+]] = arith.constant 1 : index
+ // CHECK-DAG: [[DIM0:%.+]] = tensor.dim [[ARG0]], [[ZERO]] : tensor<?x?xf32>
+ // CHECK-DAG: [[DIM1:%.+]] = tensor.dim [[ARG0]], [[ONE]] : tensor<?x?xf32>
+ // CHECK: [[INIT:%.+]] = tensor.empty([[DIM0]], [[DIM1]]) : tensor<?x?xf32>
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins([[ARG0]], [[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs([[INIT]] : tensor<?x?xf32>) {
+ // CHECK: ^bb0([[ARG2:%.*]]: f32, [[ARG3:%.*]]: f32, [[ARG4:%.*]]: f32):
+ // CHECK: [[ELEMENT:%.*]] = arith.addf [[ARG2]], [[ARG3]] : f32
+ // CHECK: linalg.yield [[ELEMENT]] : f32
+ // CHECK: } -> tensor<?x?xf32>
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+
+ // CHECK: return [[GENERIC]] : tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
}
// -----
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-LABEL: @test_abs_dyn
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_abs_dyn(%arg0: tensor<2x?xf32>) -> tensor<2x?xf32> {
- // CHECK: %[[C1:.+]] = arith.constant 1
- // CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C1]]
- // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]])
- // CHECK: linalg.generic
- // CHECK: math.absf
- %0 = "tosa.abs"(%arg0) : (tensor<2x?xf32>) -> tensor<2x?xf32>
- return %0 : tensor<2x?xf32>
+// CHECK-LABEL: @test_add_static_dynamic
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]:
+// CHECK-SAME: [[ARG1:%[0-9a-zA-Z_]*]]:
+func.func @test_add_static_dynamic(%arg0: tensor<?x3xf32>, %arg1: tensor<2x?xf32>) -> tensor<2x3xf32> {
+ // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2x3xf32>
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins([[ARG0]], [[ARG1]] : tensor<?x3xf32>, tensor<2x?xf32>) outs([[INIT]] : tensor<2x3xf32>) {
+ // CHECK: ^bb0([[ARG2:%.*]]: f32, [[ARG3:%.*]]: f32, [[ARG4:%.*]]: f32):
+ // CHECK: [[ELEMENT:%.*]] = arith.addf [[ARG2]], [[ARG3]] : f32
+ // CHECK: linalg.yield [[ELEMENT]] : f32
+ // CHECK: } -> tensor<2x3xf32>
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<?x3xf32>, tensor<2x?xf32>) -> tensor<2x3xf32>
+
+ // CHECK: return [[GENERIC]] : tensor<2x3xf32>
+ return %0 : tensor<2x3xf32>
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (0, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, 0)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-LABEL: @test_add_broadcast_to_static
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]:
+// CHECK-SAME: [[ARG1:%[0-9a-zA-Z_]*]]:
+func.func @test_add_broadcast_to_static(%arg0: tensor<1x5xf32>, %arg1: tensor<3x1xf32>) -> tensor<3x5xf32> {
+ // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<3x5xf32>
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins([[ARG0]], [[ARG1]] : tensor<1x5xf32>, tensor<3x1xf32>) outs([[INIT]] : tensor<3x5xf32>) {
+ // CHECK: ^bb0([[ARG2:%.*]]: f32, [[ARG3:%.*]]: f32, [[ARG4:%.*]]: f32):
+ // CHECK: [[ELEMENT:%.*]] = arith.addf [[ARG2]], [[ARG3]] : f32
+ // CHECK: linalg.yield [[ELEMENT]] : f32
+ // CHECK: } -> tensor<3x5xf32>
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x5xf32>, tensor<3x1xf32>) -> tensor<3x5xf32>
+
+ // CHECK: return [[GENERIC]] : tensor<3x5xf32>
+ return %0 : tensor<3x5xf32>
}
// -----
@@ -100,68 +159,122 @@
// -----
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> ()>
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)>
-
-// CHECK-LABEL: @test_broadcast
-// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]: tensor<1xf32
-// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]: tensor<2xf32>
-func.func @test_broadcast(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
- // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xf32>
- // CHECK: [[RESHAPE:%.+]] = "tosa.reshape"(%[[ARG0]])
- // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins([[RESHAPE]], %[[ARG1]] : tensor<f32>, tensor<2xf32>) outs([[INIT]] : tensor<2xf32>) {
- // CHECK: ^bb0(%[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32, %[[ARG4:.*]]: f32):
- // CHECK: [[ELEMENT:%.+]] = arith.addf %[[ARG2]], %[[ARG3]] : f32
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (0, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, 0)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-LABEL: @test_add_broadcast_to_dynamic
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]:
+// CHECK-SAME: [[ARG1:%[0-9a-zA-Z_]*]]:
+func.func @test_add_broadcast_to_dynamic(%arg0: tensor<1x?xf32>, %arg1: tensor<?x1xf32>) -> tensor<?x?xf32> {
+ // CHECK-DAG: [[ZERO:%.+]] = arith.constant 0 : index
+ // CHECK-DAG: [[ONE:%.+]] = arith.constant 1 : index
+ // CHECK-DAG: [[DIM0:%.+]] = tensor.dim [[ARG1]], [[ZERO]] : tensor<?x1xf32>
+ // CHECK-DAG: [[DIM1:%.+]] = tensor.dim [[ARG0]], [[ONE]] : tensor<1x?xf32>
+ // CHECK: [[INIT:%.+]] = tensor.empty([[DIM0]], [[DIM1]]) : tensor<?x?xf32>
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins([[ARG0]], [[ARG1]] : tensor<1x?xf32>, tensor<?x1xf32>) outs([[INIT]] : tensor<?x?xf32>) {
+ // CHECK: ^bb0([[ARG2:%.*]]: f32, [[ARG3:%.*]]: f32, [[ARG4:%.*]]: f32):
+ // CHECK: [[ELEMENT:%.*]] = arith.addf [[ARG2]], [[ARG3]] : f32
// CHECK: linalg.yield [[ELEMENT]] : f32
- // CHECK: } -> tensor<2xf32>
- %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<2xf32>) -> tensor<2xf32>
- return %0 : tensor<2xf32>
+ // CHECK: } -> tensor<?x?xf32>
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x?xf32>, tensor<?x1xf32>) -> tensor<?x?xf32>
+
+ // CHECK: return [[GENERIC]] : tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
}
// -----
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> ()>
-// CHECK-LABEL: @test_broadcast_swapped_args
-// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]: tensor<2xf32
-// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]: tensor<1xf32>
-func.func @test_broadcast_swapped_args(%arg0: tensor<2xf32>, %arg1: tensor<1xf32>) -> tensor<2xf32> {
+// CHECK-LABEL: @test_add_cast_result
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]:
+// CHECK-SAME: [[ARG1:%[0-9a-zA-Z_]*]]:
+func.func @test_add_cast_result(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<?xf32> {
// CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xf32>
- // CHECK: [[RESHAPE:%.+]] = "tosa.reshape"(%[[ARG1]])
- // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]], [[RESHAPE]] : tensor<2xf32>, tensor<f32>) outs([[INIT]] : tensor<2xf32>) {
- // CHECK: ^bb0(%[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32, %[[ARG4:.*]]: f32):
- // CHECK: [[ELEMENT:%.+]] = arith.addf %[[ARG2]], %[[ARG3]] : f32
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]], [[ARG1]] : tensor<2xf32>, tensor<2xf32>) outs([[INIT]] : tensor<2xf32>) {
+ // CHECK: ^bb0([[ARG2:%.*]]: f32, [[ARG3:%.*]]: f32, [[ARG4:%.*]]: f32):
+ // CHECK: [[ELEMENT:%.*]] = arith.addf [[ARG2]], [[ARG3]] : f32
// CHECK: linalg.yield [[ELEMENT]] : f32
// CHECK: } -> tensor<2xf32>
- %0 = "tosa.add"(%arg0, %arg1) : (tensor<2xf32>, tensor<1xf32>) -> tensor<2xf32>
- return %0 : tensor<2xf32>
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<?xf32>
+
+ // CHECK: [[CAST:%.+]] = tensor.cast [[GENERIC]] : tensor<2xf32> to tensor<?xf32>
+ // CHECK: return [[CAST]] : tensor<?xf32>
+ return %0 : tensor<?xf32>
}
// -----
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
-// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (0, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-LABEL: @test_multibroadcast
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]
-func.func @test_multibroadcast(%arg0: tensor<1x3xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x3xf32> {
+// CHECK-LABEL: @test_add_expand_rank
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]:
+// CHECK-SAME: [[ARG1:%[0-9a-zA-Z_]*]]:
+func.func @test_add_expand_rank(%arg0: tensor<3xf32>, %arg1: tensor<2x3xf32>) -> tensor<2x3xf32> {
+ // CHECK: [[EXPANDED:%.+]] = tensor.expand_shape [[ARG0]] {{\[\[}}0, 1]] : tensor<3xf32> into tensor<1x3xf32>
// CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2x3xf32>
- // CHECK: [[RESHAPE1:%.+]] = "tosa.reshape"(%[[ARG0]]) <{new_shape = array<i64: 3>}
- // CHECK: [[RESHAPE2:%.+]] = "tosa.reshape"(%[[ARG1]]) <{new_shape = array<i64: 2>}
- // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins([[RESHAPE1]], [[RESHAPE2]] : tensor<3xf32>, tensor<2xf32>) outs([[INIT]] : tensor<2x3xf32>) {
- // CHECK: ^bb0(%[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32, %[[ARG4:.*]]: f32):
- // CHECK: [[ELEMENT:%.+]] = arith.addf %[[ARG2]], %[[ARG3]] : f32
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins([[EXPANDED]], [[ARG1]] : tensor<1x3xf32>, tensor<2x3xf32>) outs([[INIT]] : tensor<2x3xf32>) {
+ // CHECK: ^bb0([[ARG2:%.*]]: f32, [[ARG3:%.*]]: f32, [[ARG4:%.*]]: f32):
+ // CHECK: [[ELEMENT:%.*]] = arith.addf [[ARG2]], [[ARG3]] : f32
// CHECK: linalg.yield [[ELEMENT]] : f32
// CHECK: } -> tensor<2x3xf32>
- %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+
+ // CHECK: return [[GENERIC]] : tensor<2x3xf32>
return %0 : tensor<2x3xf32>
}
// -----
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (0, d1, d2)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK-LABEL: @test_select
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]:
+// CHECK-SAME: [[ARG1:%[0-9a-zA-Z_]*]]:
+// CHECK-SAME: [[ARG2:%[0-9a-zA-Z_]*]]:
+func.func @test_select(%arg0: tensor<1x?x?xi1>, %arg1: tensor, %arg2: tensor) -> tensor {
+ // CHECK-DAG: [[ZERO:%.+]] = arith.constant 0 : index
+ // CHECK-DAG: [[DIM0:%.+]] = tensor.dim [[ARG1]], [[ZERO]] : tensor
+ // CHECK: [[INIT:%.+]] = tensor.empty([[DIM0]]) : tensor
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins([[ARG0]], [[ARG1]], [[ARG2]] : tensor<1x?x?xi1>, tensor, tensor) outs([[INIT]] : tensor) {
+ // CHECK: ^bb0([[ARG3:%.+]]: i1, [[ARG4:%.+]]: f32, [[ARG5:%.+]]: f32, [[ARG6:%.+]]: f32):
+ // CHECK: [[ELEMENT:%.+]] = arith.select [[ARG3]], [[ARG4]], [[ARG5]] : f32
+ // CHECK: linalg.yield [[ELEMENT]] : f32
+ // CHECK: } -> tensor
+ %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x?x?xi1>, tensor, tensor) -> tensor
+
+ // CHECK: return [[GENERIC]] : tensor
+ return %0 : tensor
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (0, 0, d2)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (0, d1, d2)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK-LABEL: @test_select_expand_rank
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]:
+// CHECK-SAME: [[ARG1:%[0-9a-zA-Z_]*]]:
+// CHECK-SAME: [[ARG2:%[0-9a-zA-Z_]*]]:
+func.func @test_select_expand_rank(%arg0: tensor<4xi1>, %arg1: tensor<3x4xf32>, %arg2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+ // CHECK: [[EXPANDED0:%.+]] = tensor.expand_shape [[ARG0]] {{\[\[}}0, 1, 2]] : tensor<4xi1> into tensor<1x1x4xi1>
+ // CHECK: [[EXPANDED1:%.+]] = tensor.expand_shape [[ARG1]] {{\[\[}}0, 1], [2]] : tensor<3x4xf32> into tensor<1x3x4xf32>
+ // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2x3x4xf32>
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel"]} ins([[EXPANDED0]], [[EXPANDED1]], [[ARG2]] : tensor<1x1x4xi1>, tensor<1x3x4xf32>, tensor<2x3x4xf32>) outs([[INIT]] : tensor<2x3x4xf32>) {
+ // CHECK: ^bb0([[ARG3:%.+]]: i1, [[ARG4:%.+]]: f32, [[ARG5:%.+]]: f32, [[ARG6:%.+]]: f32):
+ // CHECK: [[ELEMENT:%.+]] = arith.select [[ARG3]], [[ARG4]], [[ARG5]] : f32
+ // CHECK: linalg.yield [[ELEMENT]] : f32
+ // CHECK: } -> tensor<2x3x4xf32>
+ %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<4xi1>, tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
+ return %0 : tensor<2x3x4xf32>
+}
+
+// -----
+
// CHECK-LABEL: @test_simple_f32
func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
// CHECK: linalg.generic
@@ -1412,20 +1525,6 @@
// -----
-// Regression test for using the wrong rank.
-
-// CHECK-DAG: affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
-// CHECK-DAG: affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-DAG: affine_map<(d0, d1, d2, d3) -> ()>
-// CHECK-LABEL: @select_fp32
-func.func @select_fp32(%arg0: tensor<1x1x5x5xi1>, %arg1: tensor<1x12x5x5xf32>, %arg2: tensor<f32>) -> tensor<1x12x5x5xf32> {
- // CHECK: linalg.generic
- %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x1x5x5xi1>, tensor<1x12x5x5xf32>, tensor<f32>) -> tensor<1x12x5x5xf32>
- return %0 : tensor<1x12x5x5xf32>
-}
-
-// -----
-
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
diff --git a/mlir/test/Dialect/traits.mlir b/mlir/test/Dialect/traits.mlir
--- a/mlir/test/Dialect/traits.mlir
+++ b/mlir/test/Dialect/traits.mlir
@@ -111,9 +111,18 @@
// -----
-func.func @broadcast_tensor_tensor_tensor(%arg0: tensor, %arg1: tensor<*xi32>) -> tensor {
- %0 = "test.broadcastable"(%arg0, %arg1) : (tensor, tensor<*xi32>) -> tensor
- return %0 : tensor
+// Error for inferred dynamic dimension but existing static dimensions
+func.func @broadcast_tensor_tensor_tensor(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<2xi32> {
+ // expected-error @+1 {{op result type '2' not broadcast compatible with broadcasted operands's shapes '?'}}
+ %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<2xi32>
+ return %0 : tensor<2xi32>
+}
+
+// -----
+
+func.func @broadcast_tensor_tensor_tensor(%arg0: tensor, %arg1: tensor<*xi32>) -> tensor {
+ %0 = "test.broadcastable"(%arg0, %arg1) : (tensor, tensor<*xi32>) -> tensor
+ return %0 : tensor
}
// -----
@@ -145,10 +154,19 @@
// -----
-func.func @broadcast_tensor_tensor_tensor(tensor, tensor<7x1x5xi32>) -> tensor<8x7x6x5xi32> {
-^bb0(%arg0: tensor, %arg1: tensor<7x1x5xi32>):
- %0 = "test.broadcastable"(%arg0, %arg1) : (tensor, tensor<7x1x5xi32>) -> tensor<8x7x6x5xi32>
- return %0 : tensor<8x7x6x5xi32>
+// Correct use of broadcast semantics for input dimensions
+func.func @broadcast_tensor_tensor_tensor(%arg0: tensor, %arg1: tensor<7x1x5xi32>) -> tensor {
+ %0 = "test.broadcastable"(%arg0, %arg1) : (tensor, tensor<7x1x5xi32>) -> tensor
+ return %0 : tensor
+}
+
+// -----
+
+// Incorrect attempt to use broadcast semantics for result
+func.func @broadcast_tensor_tensor_tensor(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<5xi32> {
+ // expected-error @+1 {{op result type '5' not broadcast compatible with broadcasted operands's shapes '1'}}
+ %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<5xi32>
+ return %0 : tensor<5xi32>
}
// -----