diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -32,7 +32,10 @@
 //===----------------------------------------------------------------------===//
 // Operator: argmax
 //===----------------------------------------------------------------------===//
-def Tosa_ArgMaxOp : Tosa_Op<"argmax", [NoSideEffect]> {
+def Tosa_ArgMaxOp : Tosa_Op<"argmax", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect]> {
   let summary = "Perform argmax on the input.";
 
   let description = [{
@@ -173,7 +176,10 @@
 //===----------------------------------------------------------------------===//
 // Operator: fully_connected
 //===----------------------------------------------------------------------===//
-def Tosa_FullyConnectedOp : Tosa_Op<"fully_connected", [NoSideEffect]> {
+def Tosa_FullyConnectedOp : Tosa_Op<"fully_connected", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect]> {
   let summary = "Fully Connected operator";
 
   let description = [{
@@ -199,7 +205,10 @@
 //===----------------------------------------------------------------------===//
 // Operator: matmul
 //===----------------------------------------------------------------------===//
-def Tosa_MatMulOp : Tosa_Op<"matmul", [NoSideEffect]> {
+def Tosa_MatMulOp : Tosa_Op<"matmul", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect]> {
   let summary = "Matrix multiplication with bias";
 
   let description = [{
@@ -589,8 +598,9 @@
 // Operator: logical_right_shift
 //===----------------------------------------------------------------------===//
 def Tosa_LogicalRightShiftOp : Tosa_Op<"logical_right_shift", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface, ["inferReturnTypeComponents"]>, ResultsBroadcastableShape,
-    NoSideEffect]> {
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    ResultsBroadcastableShape, NoSideEffect]> {
   let summary = "Elementwise Logical Right Shift";
 
   let description = [{
@@ -783,7 +793,10 @@
 //===----------------------------------------------------------------------===//
 // Operator: table
 //===----------------------------------------------------------------------===//
-def Tosa_TableOp : Tosa_Op<"table", [NoSideEffect]> {
+def Tosa_TableOp : Tosa_Op<"table", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect]> {
   let summary = "Table lookup op";
 
   let description = [{
@@ -1178,7 +1191,10 @@
 //===----------------------------------------------------------------------===//
 // Operator: reduce_all
 //===----------------------------------------------------------------------===//
-def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [NoSideEffect]> {
+def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect]> {
   let summary = "Reduce All operator";
 
   let description = [{
@@ -1198,7 +1214,10 @@
 //===----------------------------------------------------------------------===//
 // Operator: reduce_any
 //===----------------------------------------------------------------------===//
-def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [NoSideEffect]> {
+def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect]> {
   let summary = "Reduce Any operator";
 
   let description = [{
@@ -1218,7 +1237,10 @@
 //===----------------------------------------------------------------------===//
 // Operator: reduce_max
 //===----------------------------------------------------------------------===//
-def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [NoSideEffect]> {
+def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect]> {
   let summary = "Reduce Max operator";
 
   let description = [{
@@ -1238,7 +1260,10 @@
 //===----------------------------------------------------------------------===//
 // Operator: reduce_min
 //===----------------------------------------------------------------------===//
-def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [NoSideEffect]> {
+def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect]> {
   let summary = "Reduce Min operator";
 
   let description = [{
@@ -1258,7 +1283,10 @@
 //===----------------------------------------------------------------------===//
 // Operator: reduce_prod
 //===----------------------------------------------------------------------===//
-def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [NoSideEffect]> {
+def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect]> {
   let summary = "Reduce Prod operator";
 
   let description = [{
@@ -1278,7 +1306,10 @@
 //===----------------------------------------------------------------------===//
 // Operator: reduce_sum
 //===----------------------------------------------------------------------===//
-def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [NoSideEffect]> {
+def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect]> {
   let summary = "Reduce Sum operator";
 
   let description = [{
@@ -1303,7 +1334,10 @@
 //===----------------------------------------------------------------------===//
 // Operator: concat
 //===----------------------------------------------------------------------===//
-def Tosa_ConcatOp : Tosa_Op<"concat", [NoSideEffect]> {
+def Tosa_ConcatOp : Tosa_Op<"concat", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect]> {
   let summary = "Concatenates tensors along one dimension.";
 
   let description = [{
@@ -1324,7 +1358,10 @@
 //===----------------------------------------------------------------------===//
 // Operator: pad
 //===----------------------------------------------------------------------===//
-def Tosa_PadOp : Tosa_Op<"pad", [NoSideEffect]> {
+def Tosa_PadOp : Tosa_Op<"pad", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect]> {
   let summary = "Pads a tensor with zeros.";
 
   let description = [{
@@ -1396,7 +1433,9 @@
 //===----------------------------------------------------------------------===//
 // Operator: slice
 //===----------------------------------------------------------------------===//
-def Tosa_SliceOp: Tosa_Op<"slice", [NoSideEffect]> {
+def Tosa_SliceOp: Tosa_Op<"slice", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>, NoSideEffect]> {
   let summary = "Slice operator";
 
   let description = [{
@@ -1419,7 +1458,10 @@
 //===----------------------------------------------------------------------===//
 // Operator: tile
 //===----------------------------------------------------------------------===//
-def Tosa_TileOp: Tosa_Op<"tile", [NoSideEffect]> {
+def Tosa_TileOp: Tosa_Op<"tile", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect]> {
   let summary = "Tile operator";
 
   let description = [{
@@ -1438,7 +1480,10 @@
 //===----------------------------------------------------------------------===//
 // Operator: transpose
 //===----------------------------------------------------------------------===//
-def Tosa_TransposeOp : Tosa_Op<"transpose", [NoSideEffect]> {
+def Tosa_TransposeOp : Tosa_Op<"transpose", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect]> {
   let summary = "Transpose operator";
 
   let description = [{
@@ -1463,7 +1508,10 @@
 //===----------------------------------------------------------------------===//
 // Operator: gather
 //===----------------------------------------------------------------------===//
-def Tosa_GatherOp : Tosa_Op<"gather", [NoSideEffect]> {
+def Tosa_GatherOp : Tosa_Op<"gather", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect]> {
   let summary = "Gather operation.";
 
   let description = [{
@@ -1484,7 +1532,10 @@
 //===----------------------------------------------------------------------===//
 // Operator: scatter
 //===----------------------------------------------------------------------===//
-def Tosa_ScatterOp : Tosa_Op<"scatter", [NoSideEffect]> {
+def Tosa_ScatterOp : Tosa_Op<"scatter", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect]> {
   let summary = "Scatter operation.";
 
   let description = [{
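Each def above now declares InferShapedTypeOpInterface with an inferReturnTypeComponents override, which is what lets a generic driver ask the op for its result shapes. A minimal sketch of such a consumer, assuming a plain Operation *op (this driver is not part of this diff, and applyInferredShapes is a hypothetical helper named only for exposition):

  if (auto shapeInterface = dyn_cast<InferShapedTypeOpInterface>(op)) {
    // Ask the op to describe its results from its operands and attributes.
    SmallVector<ShapedTypeComponents> components;
    if (succeeded(shapeInterface.inferReturnTypeComponents(
            op->getContext(), op->getLoc(), op->getOperands(),
            op->getAttrDictionary(), op->getRegions(), components)))
      // Each entry is either unranked, or a rank with -1 for unknown dims.
      applyInferredShapes(op, components); // hypothetical helper
  }
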
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "mlir/Transforms/RegionUtils.h"
@@ -301,6 +302,260 @@
   }
 }
 
+LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
+    MLIRContext *context, ::llvm::Optional<Location> location,
+    ValueRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  ShapedType inputTy = operands[0].getType().cast<ShapedType>();
+  IntegerAttr axis = attributes.get("axis").cast<IntegerAttr>();
+  int32_t axisVal = axis.getValue().getSExtValue();
+
+  if (!inputTy.hasRank()) {
+    inferredReturnShapes.push_back(ShapedTypeComponents());
+    return success();
+  }
+
+  SmallVector<int64_t> outShape;
+  outShape.reserve(inputTy.getRank() - 1);
+  for (int i = 0, s = inputTy.getRank(); i < s; i++) {
+    if (i == axisVal)
+      continue;
+    outShape.push_back(inputTy.getDimSize(i));
+  }
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
+  return success();
+}
+
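+// Illustration (not from the original patch): for a tensor<2x3xi32> input
+// with axis = 1, the loop above keeps dim 0 and drops dim 1, so the inferred
+// result is tensor<2xi32>, matching @test_static_argmax below.
+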
+LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
+    MLIRContext *context, ::llvm::Optional<Location> location,
+    ValueRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  // Infer all dimension sizes by reducing based on inputs.
+  int32_t axis =
+      attributes.get("axis").cast<IntegerAttr>().getValue().getSExtValue();
+  llvm::SmallVector<int64_t> outputShape;
+  bool hasRankedInput = false;
+  for (auto operand : operands) {
+    ShapedType operandTy = operand.getType().cast<ShapedType>();
+    if (!operandTy.hasRank())
+      continue;
+
+    // Copy the operand's rank.
+    if (!hasRankedInput)
+      outputShape.resize(operandTy.getRank(), -1);
+
+    // Copy each static, non-axis dimension and require that ranked inputs
+    // agree on it.
+    for (int i = 0, s = operandTy.getRank(); i < s; i++) {
+      if (i == axis || operandTy.isDynamicDim(i))
+        continue;
+      if (outputShape[i] == -1)
+        outputShape[i] = operandTy.getDimSize(i);
+      if (outputShape[i] != operandTy.getDimSize(i))
+        return failure();
+    }
+
+    hasRankedInput = true;
+  }
+
+  if (!hasRankedInput) {
+    inferredReturnShapes.push_back(ShapedTypeComponents());
+    return success();
+  }
+
+  // Determine the dimension size along the concatenation axis.
+  int concatDimSize = 0;
+  for (auto operand : operands) {
+    ShapedType operandTy = operand.getType().cast<ShapedType>();
+
+    // We need to know the length of the concatenation axis of all inputs to
+    // determine the dimension size of the output shape.
+    if (!operandTy.hasRank() || operandTy.isDynamicDim(axis)) {
+      concatDimSize = -1;
+      break;
+    }
+
+    concatDimSize += operandTy.getDimSize(axis);
+  }
+
+  outputShape[axis] = concatDimSize;
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  return success();
+}
+
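+// Illustration (not from the original patch): concatenating tensor<1x2xf32>
+// with tensor<2x2xf32> on axis 0 sums 1 + 2 = 3 along the axis and keeps the
+// agreeing non-axis size 2, giving tensor<3x2xf32>; a dynamic axis dimension
+// on any input forces the axis size to -1 (unknown).
+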
+LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
+    MLIRContext *context, ::llvm::Optional<Location> location,
+    ValueRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  ShapedType inputTy = operands[0].getType().cast<ShapedType>();
+  ShapedType weightTy = operands[1].getType().cast<ShapedType>();
+  ShapedType biasTy = operands[2].getType().cast<ShapedType>();
+
+  // All shapes are dynamic.
+  SmallVector<int64_t> outShape;
+  outShape.resize(2, -1);
+
+  if (inputTy.hasRank()) {
+    outShape[0] = inputTy.getDimSize(0);
+  }
+
+  if (weightTy.hasRank()) {
+    outShape[1] = weightTy.getDimSize(0);
+  }
+
+  if (biasTy.hasRank()) {
+    outShape[1] = outShape[1] == -1 ? biasTy.getDimSize(0) : outShape[1];
+  }
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
+  return success();
+}
+
+LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
+    MLIRContext *context, ::llvm::Optional<Location> location,
+    ValueRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  ShapedType lhsTy = operands[0].getType().cast<ShapedType>();
+  ShapedType rhsTy = operands[1].getType().cast<ShapedType>();
+
+  // All shapes are dynamic.
+  SmallVector<int64_t> outShape;
+  outShape.resize(3, -1);
+
+  if (lhsTy.hasRank()) {
+    outShape[0] = lhsTy.getDimSize(0);
+    outShape[1] = lhsTy.getDimSize(1);
+  }
+
+  if (rhsTy.hasRank()) {
+    outShape[0] = outShape[0] == -1 ? rhsTy.getDimSize(0) : outShape[0];
+    outShape[2] = rhsTy.getDimSize(2);
+  }
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
+  return success();
+}
+
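+// Illustration (not from the original patch): for tensor<?x3x?xi32> times
+// tensor<?x?x5xi32>, the batch dimension stays unknown (dynamic on both
+// sides) while dims 1 and 2 resolve to 3 and 5, giving tensor<?x3x5xi32>
+// (see @test_dynamic_mixed_matmul below).
+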
+LogicalResult tosa::PadOp::inferReturnTypeComponents(
+    MLIRContext *context, ::llvm::Optional<Location> location,
+    ValueRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  ShapedType inputTy = operands[0].getType().cast<ShapedType>();
+  ShapedType paddingTy = operands[1].getType().cast<ShapedType>();
+  SmallVector<int64_t> outputShape;
+
+  // If both inputs have unknown shape, we cannot determine the shape of the
+  // output.
+  if (!inputTy.hasRank() && !paddingTy.hasRank()) {
+    inferredReturnShapes.push_back(ShapedTypeComponents());
+    return success();
+  }
+
+  // If the input rank is unknown, we can infer the output rank from the
+  // padding shape's first dim.
+  if (!inputTy.hasRank()) {
+    if (paddingTy.isDynamicDim(0)) {
+      inferredReturnShapes.push_back(ShapedTypeComponents());
+      return success();
+    }
+
+    outputShape.resize(paddingTy.getDimSize(0), -1);
+    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+    return success();
+  }
+
+  DenseIntElementsAttr paddings;
+  // If the paddings value is not a constant, all dimensions must be dynamic.
+  if (!matchPattern(operands[1], m_Constant(&paddings))) {
+    outputShape.resize(inputTy.getRank(), -1);
+    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+    return success();
+  }
+
+  SmallVector<int64_t> paddingValues;
+  for (auto val : paddings) {
+    paddingValues.push_back(val.getSExtValue());
+  }
+
+  outputShape.reserve(inputTy.getRank());
+  for (int i = 0, s = inputTy.getRank(); i < s; i++) {
+    if (inputTy.isDynamicDim(i)) {
+      outputShape.push_back(-1);
+      continue;
+    }
+
+    outputShape.push_back(inputTy.getDimSize(i) + paddingValues[i * 2] +
+                          paddingValues[i * 2 + 1]);
+  }
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  return success();
+}
+
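+// Illustration (not from the original patch): padding tensor<1x2xf32> with
+// the constant [[1, 2], [3, 4]] yields 1 + 1 + 2 = 4 and 2 + 3 + 4 = 9,
+// i.e. tensor<4x9xf32>, matching @test_padding_simple below.
+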
"tosa.reciprocal"(%arg0) : (tensor<4xf32>) -> tensor<*xf32> + + // CHECK: "tosa.reluN"(%arg0) {{.+}} : (tensor<4xf32>) -> tensor<4xf32> + %8 = "tosa.reluN"(%arg0) { max_int = 10 : i64, min_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 10.0 : f32 } : (tensor<4xf32>) -> tensor<*xf32> + + // CHECK: "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor<4xf32>) -> tensor<4xf32> + %9 = "tosa.reverse"(%arg0) { axis = 0 : i64 } : (tensor<4xf32>) -> tensor + + // CHECK: "tosa.rsqrt"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> + %10 = "tosa.rsqrt"(%arg0) : (tensor<4xf32>) -> tensor<*xf32> + + // CHECK: "tosa.tanh"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> + %11 = "tosa.tanh"(%arg0) : (tensor<4xf32>) -> tensor<*xf32> + + // CHECK: "tosa.sigmoid"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> + %12 = "tosa.sigmoid"(%arg0) : (tensor<4xf32>) -> tensor<*xf32> + return +} + +// ----- + +// CHECK-LABEL: @test_unary_i32 +func @test_unary_i32(%arg0 : tensor<4xi32>) -> () { + // CHECK: "tosa.abs"(%arg0) : (tensor<4xi32>) -> tensor<4xi32> + %0 = "tosa.abs"(%arg0) : (tensor<4xi32>) -> tensor<*xi32> + + // CHECK: "tosa.bitwise_not"(%arg0) : (tensor<4xi32>) -> tensor<4xi32> + %1 = "tosa.bitwise_not"(%arg0) : (tensor<4xi32>) -> tensor<*xi32> + + // CHECK: "tosa.clamp"(%arg0) {{.+}} : (tensor<4xi32>) -> tensor<4xi32> + %2 = "tosa.clamp"(%arg0) { max_int = 10 : i64, min_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 10.0 : f32 } : (tensor<4xi32>) -> tensor<*xi32> + + // CHECK: "tosa.clz"(%arg0) : (tensor<4xi32>) -> tensor<4xi32> + %3 = "tosa.clz"(%arg0) : (tensor<4xi32>) -> tensor<*xi32> + + // CHECK: "tosa.negate"(%arg0) : (tensor<4xi32>) -> tensor<4xi32> + %4 = "tosa.negate"(%arg0) : (tensor<4xi32>) -> tensor<*xi32> + + // CHECK: "tosa.reluN"(%arg0) {{.+}} : (tensor<4xi32>) -> tensor<4xi32> + %5 = "tosa.reluN"(%arg0) { max_int = 10 : i64, min_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 10.0 : f32 } : (tensor<4xi32>) -> tensor<*xi32> + + // CHECK: "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor<4xi32>) -> tensor<4xi32> + %6 = "tosa.reverse"(%arg0) { axis = 0 : i64 } : (tensor<4xi32>) -> tensor + return +} + +// ----- + +// CHECK-LABEL: @test_unary_i1 +func @test_unary_i1(%arg0 : tensor<4xi1>) -> () { + // CHECK: "tosa.logical_not"(%arg0) : (tensor<4xi1>) -> tensor<4xi1> + %0 = "tosa.logical_not"(%arg0) : (tensor<4xi1>) -> tensor<*xi1> + return +} + +// ----- + +// CHECK-LABEL: @test_binary_scalar_f32 +func @test_binary_scalar_f32(%arg0 : tensor<4xf32>, %arg1 : tensor) -> () { + // CHECK: "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xf32> + %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xf32> + + // CHECK: "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xf32> + %1 = "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xf32> + + // CHECK: "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xf32> + %2 = "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xf32> + + // CHECK: "tosa.mul"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xf32>, tensor) -> tensor<4xf32> + %3 = "tosa.mul"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xf32>, tensor) -> tensor<*xf32> + + // CHECK: "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xf32> + %4 = "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xf32> + + // CHECK: "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xf32> + %5 = "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xf32> + + // CHECK: "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xi1> 
+ %6 = "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xi1> + + // CHECK: "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xi1> + %7 = "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xi1> + + // CHECK: "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xi1> + %8 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xi1> + + return +} + +// ----- + +// CHECK-LABEL: @test_binary_broadcast_f32 +func @test_binary_broadcast_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>) -> () { + // CHECK: "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> + %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> + + // CHECK: "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> + %1 = "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> + + // CHECK: "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> + %2 = "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> + + // CHECK: "tosa.mul"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> + %3 = "tosa.mul"(%arg0, %arg1) { shift = 0 : i32 } : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> + + // CHECK: "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> + %4 = "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> + + // CHECK: "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> + %5 = "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> + + // CHECK: "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1> + %6 = "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> + + // CHECK: "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1> + %7 = "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> + + // CHECK: "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1> + %8 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> + + return +} + +// ----- + +// CHECK-LABEL: @test_binary_i32 +func @test_binary_i32(%arg0 : tensor<4xi32>, %arg1 : tensor) -> () { + // CHECK: "tosa.add"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi32> + %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi32> + + // CHECK: "tosa.bitwise_and"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor) -> tensor<4xi32> + %1 = "tosa.bitwise_and"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor) -> tensor<*xi32> + + // CHECK: "tosa.bitwise_or"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor) -> tensor<4xi32> + %2 = "tosa.bitwise_or"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor) -> tensor<*xi32> + + // CHECK: "tosa.bitwise_xor"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor) -> tensor<4xi32> + %3 = "tosa.bitwise_xor"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor) -> tensor<*xi32> + + // CHECK: "tosa.equal"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi1> + %4 = "tosa.equal"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi1> + + // CHECK: "tosa.greater"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi1> + %5 = "tosa.greater"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi1> + + // CHECK: "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi1> + %6 = "tosa.greater_equal"(%arg0, 
%arg1) : (tensor<4xi32>, tensor) -> tensor<*xi1> + + // CHECK: "tosa.logical_left_shift"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor) -> tensor<4xi32> + %7 = "tosa.logical_left_shift"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor) -> tensor<*xi32> + + // CHECK: "tosa.logical_right_shift"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor) -> tensor<4xi32> + %8 = "tosa.logical_right_shift"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor) -> tensor<*xi32> + + // CHECK: "tosa.maximum"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi32> + %9 = "tosa.maximum"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi32> + + // CHECK: "tosa.minimum"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi32> + %10 = "tosa.minimum"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi32> + + // CHECK: "tosa.mul"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor) -> tensor<4xi32> + %11 = "tosa.mul"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor) -> tensor<*xi32> + + // CHECK: "tosa.pow"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi32> + %12 = "tosa.pow"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi32> + + // CHECK: "tosa.sub"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi32> + %13 = "tosa.sub"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi32> + + return +} + +// ----- + +// CHECK-LABEL: @test_binary_i1 +func @test_binary_i1(%arg0 : tensor<4xi1>, %arg1 : tensor) -> () { + // CHECK "tosa.logical_and"(%arg0, %arg1) : (tensor<4xi1>, tensor) -> tensor<4xi1> + %0 = "tosa.logical_and"(%arg0, %arg1): (tensor<4xi1>, tensor) -> tensor<*xi1> + + // CHECK "tosa.logical_or"(%arg0, %arg1) : (tensor<4xi1>, tensor) -> tensor<4xi1> + %1 = "tosa.logical_or"(%arg0, %arg1): (tensor<4xi1>, tensor) -> tensor<*xi1> + + // CHECK "tosa.logical_xor"(%arg0, %arg1) : (tensor<4xi1>, tensor) -> tensor<*4i1> + %2 = "tosa.logical_xor"(%arg0, %arg1): (tensor<4xi1>, tensor) -> tensor<*xi1> + + return +} + +// ----- + +// CHECK-LABEL: @test_select_i32 +func @test_select_i32(%arg0 : tensor<4xi1>, %arg1 : tensor, %arg2 : tensor<4xi32>) -> () { + // CHECK: "tosa.select"(%arg0, %arg1, %arg2) : (tensor<4xi1>, tensor, tensor<4xi32>) -> tensor<4xi32> + %0 = "tosa.select"(%arg0, %arg1, %arg2): (tensor<4xi1>, tensor, tensor<4xi32>) -> tensor<*xi32> + + return +} + +// ----- + +// CHECK-LABEL: @test_static_argmax +func @test_static_argmax(%arg0 : tensor<2x3xi32>) -> () { + // CHECK: "tosa.argmax"(%arg0) {axis = 0 : i64} : (tensor<2x3xi32>) -> tensor<3xi32> + %0 = "tosa.argmax"(%arg0) {axis = 0 : i64} : (tensor<2x3xi32>) -> tensor + + // CHECK: "tosa.argmax"(%arg0) {axis = 1 : i64} : (tensor<2x3xi32>) -> tensor<2xi32> + %1 = "tosa.argmax"(%arg0) {axis = 1 : i64} : (tensor<2x3xi32>) -> tensor + return +} + +// ----- + +// CHECK-LABEL: @test_dynamic_argmax +func @test_dynamic_argmax(%arg0 : tensor<2x?xi32>) -> () { + // CHECK: "tosa.argmax"(%arg0) {axis = 0 : i64} : (tensor<2x?xi32>) -> tensor + %0 = "tosa.argmax"(%arg0) {axis = 0 : i64} : (tensor<2x?xi32>) -> tensor + + // CHECK: "tosa.argmax"(%arg0) {axis = 1 : i64} : (tensor<2x?xi32>) -> tensor<2xi32> + %1 = "tosa.argmax"(%arg0) {axis = 1 : i64} : (tensor<2x?xi32>) -> tensor + return +} + +// ----- + +// CHECK-LABEL: @test_static_fully_connected +func @test_static_fully_connected(%arg0 : tensor<3x4xf32>, %arg1 : tensor<5x4xf32>, %arg2 : tensor<5xf32>) -> () { + // CHECK: "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<3x4xf32>, tensor<5x4xf32>, tensor<5xf32>) -> tensor<3x5xf32> + %0 = 
"tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<3x4xf32>, tensor<5x4xf32>, tensor<5xf32>) -> tensor + return +} + +// ----- + +// CHECK-LABEL: @test_static_input_fully_connected +func @test_static_input_fully_connected(%arg0 : tensor<3x4xf32>, %arg1 : tensor, %arg2 : tensor) -> () { + // CHECK: "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<3x4xf32>, tensor, tensor) -> tensor<3x?xf32> + %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<3x4xf32>, tensor, tensor) -> tensor + return +} + +// ----- + +// CHECK-LABEL: @test_static_weight_fully_connected +func @test_static_weight_fully_connected(%arg0 : tensor, %arg1 : tensor<5x4xf32>, %arg2 : tensor) -> () { + // CHECK: "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor, tensor<5x4xf32>, tensor) -> tensor + %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor, tensor<5x4xf32>, tensor) -> tensor + return +} + +// ----- + +// CHECK-LABEL: @test_static_bias_fully_connected +func @test_static_bias_fully_connected(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor<5xf32>) -> () { + // CHECK: "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor<5xf32>) -> tensor + %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor<5xf32>) -> tensor + return +} + +// ----- + +// CHECK-LABEL: @test_static_out_fully_connected +func @test_static_out_fully_connected(%arg0 : tensor<3x?xf32>, %arg1 : tensor, %arg2 : tensor<5xf32>) -> () { + // CHECK: "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<3x?xf32>, tensor, tensor<5xf32>) -> tensor<3x5xf32> + %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<3x?xf32>, tensor, tensor<5xf32>) -> tensor + return +} + +// ----- + +// CHECK-LABEL: @test_static_matmul +func @test_static_matmul(%arg0 : tensor<2x3x4xi32>, %arg1 : tensor<2x4x5xi32>) -> () { + // CHECK: "tosa.matmul"(%arg0, %arg1) : (tensor<2x3x4xi32>, tensor<2x4x5xi32>) -> tensor<2x3x5xi32> + %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<2x3x4xi32>, tensor<2x4x5xi32>) -> tensor + + return +} + +// ----- + +// CHECK-LABEL: @test_dynamic_lhs_matmul +func @test_dynamic_lhs_matmul(%arg0 : tensor, %arg1 : tensor<2x4x5xi32>) -> () { + // CHECK: "tosa.matmul"(%arg0, %arg1) : (tensor, tensor<2x4x5xi32>) -> tensor<2x?x5xi32> + %0 = "tosa.matmul"(%arg0, %arg1) : (tensor, tensor<2x4x5xi32>) -> tensor + + return +} + +// ----- + +// CHECK-LABEL: @test_dynamic_rhs_matmul +func @test_dynamic_rhs_matmul(%arg0 : tensor<2x3x4xi32>, %arg1 : tensor) -> () { + // CHECK: "tosa.matmul"(%arg0, %arg1) : (tensor<2x3x4xi32>, tensor) -> tensor<2x3x?xi32> + %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<2x3x4xi32>, tensor) -> tensor + + return +} + +// ----- + +// CHECK-LABEL: @test_dynamic_mixed_matmul +func @test_dynamic_mixed_matmul(%arg0 : tensor, %arg1 : tensor) -> () { + // CHECK: "tosa.matmul"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "tosa.matmul"(%arg0, %arg1) : (tensor, tensor) -> tensor + + return +} + +// ----- + +// CHECK-LABLE: @test_table_static +func @test_table_static(%arg0 : tensor<4x5xi16>, %arg1 : tensor<513xi16>) -> () { + // CHECK:"tosa.table"(%arg0, %arg1) : (tensor<4x5xi16>, tensor<513xi16>) -> tensor<4x5xi16> + %0 = "tosa.table"(%arg0, %arg1) : (tensor<4x5xi16>, tensor<513xi16>) -> tensor + return +} + +// ----- + +// CHECK-LABLE: @test_table_dynamic +func @test_table_dynamic(%arg0 : tensor<4x?xi16>, %arg1 : tensor<513xi16>) -> () { + // CHECK:"tosa.table"(%arg0, %arg1) : (tensor<4x?xi16>, tensor<513xi16>) -> tensor<4x?xi16> + %0 = "tosa.table"(%arg0, %arg1) : (tensor<4x?xi16>, tensor<513xi16>) 
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -0,0 +1,662 @@
+// RUN: mlir-opt --split-input-file --tosa-infer-shapes %s | FileCheck %s
+
+// CHECK-LABEL: @test_return
+func @test_return(%arg0 : tensor<4xf32>) -> tensor<*xf32> {
+  // CHECK: [[LOG:%.+]] = "tosa.log"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  // CHECK: tensor.cast [[LOG]] : tensor<4xf32> to tensor<*xf32>
+  %0 = "tosa.log"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_multiple
+func @test_multiple(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>, %arg2 : tensor<f32>) -> tensor<*xf32> {
+  // CHECK: [[ADD:%.+]] = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+
+  // CHECK: [[LOG:%.+]] = "tosa.log"(%0) : (tensor<4xf32>) -> tensor<4xf32>
+  %1 = "tosa.log"(%0) : (tensor<*xf32>) -> tensor<*xf32>
+
+  // CHECK: [[SUB:%.+]] = "tosa.sub"(%0, %arg2) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
+  %2 = "tosa.sub"(%0, %arg2) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_unary_f32
+func @test_unary_f32(%arg0 : tensor<4xf32>) -> () {
+  // CHECK: "tosa.abs"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %0 = "tosa.abs"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.ceil"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %1 = "tosa.ceil"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.clamp"(%arg0) {{.+}} : (tensor<4xf32>) -> tensor<4xf32>
+  %2 = "tosa.clamp"(%arg0) { max_int = 10 : i64, min_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 10.0 : f32 } : (tensor<4xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.exp"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %3 = "tosa.exp"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.floor"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %4 = "tosa.floor"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.log"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %5 = "tosa.log"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.negate"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %6 = "tosa.negate"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.reciprocal"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %7 = "tosa.reciprocal"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.reluN"(%arg0) {{.+}} : (tensor<4xf32>) -> tensor<4xf32>
+  %8 = "tosa.reluN"(%arg0) { max_int = 10 : i64, min_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 10.0 : f32 } : (tensor<4xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor<4xf32>) -> tensor<4xf32>
+  %9 = "tosa.reverse"(%arg0) { axis = 0 : i64 } : (tensor<4xf32>) -> tensor<?xf32>
+
+  // CHECK: "tosa.rsqrt"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %10 = "tosa.rsqrt"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.tanh"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %11 = "tosa.tanh"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.sigmoid"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %12 = "tosa.sigmoid"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_unary_i32
+func @test_unary_i32(%arg0 : tensor<4xi32>) -> () {
+  // CHECK: "tosa.abs"(%arg0) : (tensor<4xi32>) -> tensor<4xi32>
+  %0 = "tosa.abs"(%arg0) : (tensor<4xi32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.bitwise_not"(%arg0) : (tensor<4xi32>) -> tensor<4xi32>
+  %1 = "tosa.bitwise_not"(%arg0) : (tensor<4xi32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.clamp"(%arg0) {{.+}} : (tensor<4xi32>) -> tensor<4xi32>
+  %2 = "tosa.clamp"(%arg0) { max_int = 10 : i64, min_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 10.0 : f32 } : (tensor<4xi32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.clz"(%arg0) : (tensor<4xi32>) -> tensor<4xi32>
+  %3 = "tosa.clz"(%arg0) : (tensor<4xi32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.negate"(%arg0) : (tensor<4xi32>) -> tensor<4xi32>
+  %4 = "tosa.negate"(%arg0) : (tensor<4xi32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.reluN"(%arg0) {{.+}} : (tensor<4xi32>) -> tensor<4xi32>
+  %5 = "tosa.reluN"(%arg0) { max_int = 10 : i64, min_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 10.0 : f32 } : (tensor<4xi32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor<4xi32>) -> tensor<4xi32>
+  %6 = "tosa.reverse"(%arg0) { axis = 0 : i64 } : (tensor<4xi32>) -> tensor<?xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_unary_i1
+func @test_unary_i1(%arg0 : tensor<4xi1>) -> () {
+  // CHECK: "tosa.logical_not"(%arg0) : (tensor<4xi1>) -> tensor<4xi1>
+  %0 = "tosa.logical_not"(%arg0) : (tensor<4xi1>) -> tensor<*xi1>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_binary_scalar_f32
+func @test_binary_scalar_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<f32>) -> () {
+  // CHECK: "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
+  %1 = "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
+  %2 = "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.mul"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
+  %3 = "tosa.mul"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
+  %4 = "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
+  %5 = "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xi1>
+  %6 = "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xi1>
+
+  // CHECK: "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xi1>
+  %7 = "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xi1>
+
+  // CHECK: "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xi1>
+  %8 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xi1>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_binary_broadcast_f32
+func @test_binary_broadcast_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>) -> () {
+  // CHECK: "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  %1 = "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  %2 = "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.mul"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  %3 = "tosa.mul"(%arg0, %arg1) { shift = 0 : i32 } : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  %4 = "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  %5 = "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1>
+  %6 = "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
+
+  // CHECK: "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1>
+  %7 = "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
+
+  // CHECK: "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1>
+  %8 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
+
+  return
+}
+
+// -----
+
(tensor<2x3x?x?xf32>) -> tensor + + // CHECK: "tosa.reduce_prod"(%arg0) {axis = 3 : i64} : (tensor<2x3x?x?xf32>) -> tensor<2x3x?x1xf32> + %6 = "tosa.reduce_prod"(%arg0) {axis = 3 : i64} : (tensor<2x3x?x?xf32>) -> tensor + + return +} + +// ----- + +// CHECK-LABEL: @test_concat +func @test_concat(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> () { + // CHECK: "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<3x2xf32> + %0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor + + return +} + +// ----- + +// CHECK-LABEL: @test_concat_dynamic +func @test_concat_dynamic(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x?xf32>) -> () { + // CHECK: "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor<3x2xf32> + %0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor + + return +} + +// ----- + +// CHECK-LABEL: @test_concat_dynamic_axis +func @test_concat_dynamic_axis(%arg0 : tensor, %arg1 : tensor<2x2xf32>) -> () { + // CHECK: "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor, tensor<2x2xf32>) -> tensor + %0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor, tensor<2x2xf32>) -> tensor + + return +} + +// ----- + +// CHECK-LABEL: @test_concat_axis_1 +func @test_concat_axis_1(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> () { + // CHECK: "tosa.concat"(%arg0, %arg1) {axis = 1 : i64} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<2x3xf32> + %0 = "tosa.concat"(%arg0, %arg1) {axis = 1 : i64} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor + + return +} + +// ----- + +// CHECK-LABEL: @test_concat_failure +func @test_concat_failure(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> () { + // CHECK: "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor + %0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor + + return +} + +// ----- + +// CHECK-LABEL: @test_padding_no_const +func @test_padding_no_const(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xi32>) -> () { + // CHECK: "tosa.pad"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<2x2xi32>) -> tensor + %0 = "tosa.pad"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor) + return +} + +// ----- + +// CHECK-LABEL:@test_padding_dynamic_input +func @test_padding_dynamic_input(%arg0 : tensor<1x?xf32>) -> () { + %0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> + // CHECK: "tosa.pad"(%arg0, %cst) : (tensor<1x?xf32>, tensor<2x2xi32>) -> tensor<4x?xf32> + %1 = "tosa.pad"(%arg0, %0) : (tensor<1x?xf32>, tensor<2x2xi32>) -> (tensor) + return +} + +// ----- + +// CHECK-LABEL: @test_padding_simple +func @test_padding_simple(%arg0 : tensor<1x2xf32>) -> () { + %0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> + // CHECK: "tosa.pad"(%arg0, %cst) : (tensor<1x2xf32>, tensor<2x2xi32>) -> tensor<4x9xf32> + %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor) + return +} + +// ----- + +// CHECK-LABEL: @test_slice +func @test_slice(%arg0 : tensor) -> () { + // CHECK: "tosa.slice"(%arg0) {size = [2], start = [1]} : (tensor) -> tensor<2xi32> + %0 = "tosa.slice"(%arg0) { size = [2], start = [1] } : (tensor) -> tensor + return +} + +// ----- + +// CHECK-LABEL: @test_tile +func @test_tile(%arg0 : tensor<2x3x?xi32>) -> () { + // CHECK: "tosa.tile"(%arg0) {multiples = [2, 1, 5]} : (tensor<2x3x?xi32>) -> tensor<4x3x?xi32> + %0 = "tosa.tile"(%arg0) {multiples = [2, 1, 5]} : (tensor<2x3x?xi32>) 
+// CHECK-LABEL: @test_table_static
+func @test_table_static(%arg0 : tensor<4x5xi16>, %arg1 : tensor<513xi16>) -> () {
+  // CHECK: "tosa.table"(%arg0, %arg1) : (tensor<4x5xi16>, tensor<513xi16>) -> tensor<4x5xi16>
+  %0 = "tosa.table"(%arg0, %arg1) : (tensor<4x5xi16>, tensor<513xi16>) -> tensor<?x?xi16>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_table_dynamic
+func @test_table_dynamic(%arg0 : tensor<4x?xi16>, %arg1 : tensor<513xi16>) -> () {
+  // CHECK: "tosa.table"(%arg0, %arg1) : (tensor<4x?xi16>, tensor<513xi16>) -> tensor<4x?xi16>
+  %0 = "tosa.table"(%arg0, %arg1) : (tensor<4x?xi16>, tensor<513xi16>) -> tensor<?x?xi16>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_static_reshape
+func @test_static_reshape(%arg0 : tensor<4x4xi32>) -> () {
+  // CHECK: "tosa.reshape"(%arg0) {new_shape = [16]} : (tensor<4x4xi32>) -> tensor<16xi32>
+  %0 = "tosa.reshape"(%arg0) {new_shape = [16]} : (tensor<4x4xi32>) -> tensor<?xi32>
+
+  // CHECK: "tosa.reshape"(%arg0) {new_shape = [-1]} : (tensor<4x4xi32>) -> tensor<16xi32>
+  %1 = "tosa.reshape"(%arg0) {new_shape = [-1]} : (tensor<4x4xi32>) -> tensor<?xi32>
+
+  // CHECK: "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor<4x4xi32>) -> tensor<2x8xi32>
+  %2 = "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor<4x4xi32>) -> tensor<?x?xi32>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_dynamic_reshape
+func @test_dynamic_reshape(%arg0 : tensor<4x?xi32>) -> () {
+  // CHECK: %0 = "tosa.reshape"(%arg0) {new_shape = [16]} : (tensor<4x?xi32>) -> tensor<16xi32>
+  %0 = "tosa.reshape"(%arg0) {new_shape = [16]} : (tensor<4x?xi32>) -> tensor<?xi32>
+
+  // CHECK: %1 = "tosa.reshape"(%arg0) {new_shape = [-1]} : (tensor<4x?xi32>) -> tensor<?xi32>
+  %1 = "tosa.reshape"(%arg0) {new_shape = [-1]} : (tensor<4x?xi32>) -> tensor<?xi32>
+
+  // CHECK: %2 = "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor<4x?xi32>) -> tensor<2x?xi32>
+  %2 = "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor<4x?xi32>) -> tensor<?x?xi32>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_reduce_binary
+func @test_reduce_binary(%arg0 : tensor<2x3x?x?xi1>) -> () {
+  // CHECK: "tosa.reduce_all"(%arg0) {axis = 0 : i64} : (tensor<2x3x?x?xi1>) -> tensor<1x3x?x?xi1>
+  %0 = "tosa.reduce_all"(%arg0) {axis = 0 : i64} : (tensor<2x3x?x?xi1>) -> tensor<?x?x?x?xi1>
+
+  // CHECK: "tosa.reduce_all"(%arg0) {axis = 1 : i64} : (tensor<2x3x?x?xi1>) -> tensor<2x1x?x?xi1>
+  %1 = "tosa.reduce_all"(%arg0) {axis = 1 : i64} : (tensor<2x3x?x?xi1>) -> tensor<?x?x?x?xi1>
+
+  // CHECK: "tosa.reduce_all"(%arg0) {axis = 2 : i64} : (tensor<2x3x?x?xi1>) -> tensor<2x3x1x?xi1>
+  %2 = "tosa.reduce_all"(%arg0) {axis = 2 : i64} : (tensor<2x3x?x?xi1>) -> tensor<?x?x?x?xi1>
+
+  // CHECK: "tosa.reduce_all"(%arg0) {axis = 3 : i64} : (tensor<2x3x?x?xi1>) -> tensor<2x3x?x1xi1>
+  %3 = "tosa.reduce_all"(%arg0) {axis = 3 : i64} : (tensor<2x3x?x?xi1>) -> tensor<?x?x?x?xi1>
+
+  // CHECK: "tosa.reduce_any"(%arg0) {axis = 0 : i64} : (tensor<2x3x?x?xi1>) -> tensor<1x3x?x?xi1>
+  %4 = "tosa.reduce_any"(%arg0) {axis = 0 : i64} : (tensor<2x3x?x?xi1>) -> tensor<?x?x?x?xi1>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_reduce_float
+func @test_reduce_float(%arg0 : tensor<2x3x?x?xf32>) -> () {
+  // CHECK: "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<2x3x?x?xf32>) -> tensor<1x3x?x?xf32>
+  %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<2x3x?x?xf32>) -> tensor<?x?x?x?xf32>
+
+  // CHECK: "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<2x3x?x?xf32>) -> tensor<2x1x?x?xf32>
+  %1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<2x3x?x?xf32>) -> tensor<?x?x?x?xf32>
+
+  // CHECK: "tosa.reduce_sum"(%arg0) {axis = 2 : i64} : (tensor<2x3x?x?xf32>) -> tensor<2x3x1x?xf32>
+  %2 = "tosa.reduce_sum"(%arg0) {axis = 2 : i64} : (tensor<2x3x?x?xf32>) -> tensor<?x?x?x?xf32>
+
+  // CHECK: "tosa.reduce_sum"(%arg0) {axis = 3 : i64} : (tensor<2x3x?x?xf32>) -> tensor<2x3x?x1xf32>
+  %3 = "tosa.reduce_sum"(%arg0) {axis = 3 : i64} : (tensor<2x3x?x?xf32>) -> tensor<?x?x?x?xf32>
+
+  // CHECK: "tosa.reduce_max"(%arg0) {axis = 3 : i64} : (tensor<2x3x?x?xf32>) -> tensor<2x3x?x1xf32>
+  %4 = "tosa.reduce_max"(%arg0) {axis = 3 : i64} : (tensor<2x3x?x?xf32>) -> tensor<?x?x?x?xf32>
+
+  // CHECK: "tosa.reduce_min"(%arg0) {axis = 3 : i64} : (tensor<2x3x?x?xf32>) -> tensor<2x3x?x1xf32>
+  %5 = "tosa.reduce_min"(%arg0) {axis = 3 : i64} : (tensor<2x3x?x?xf32>) -> tensor<?x?x?x?xf32>
+
+  // CHECK: "tosa.reduce_prod"(%arg0) {axis = 3 : i64} : (tensor<2x3x?x?xf32>) -> tensor<2x3x?x1xf32>
+  %6 = "tosa.reduce_prod"(%arg0) {axis = 3 : i64} : (tensor<2x3x?x?xf32>) -> tensor<?x?x?x?xf32>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_concat
+func @test_concat(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> () {
+  // CHECK: "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<3x2xf32>
+  %0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_concat_dynamic
+func @test_concat_dynamic(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x?xf32>) -> () {
+  // CHECK: "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor<3x2xf32>
+  %0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor<?x?xf32>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_concat_dynamic_axis
+func @test_concat_dynamic_axis(%arg0 : tensor<?x?xf32>, %arg1 : tensor<2x2xf32>) -> () {
+  // CHECK: "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x?xf32>, tensor<2x2xf32>) -> tensor<?x2xf32>
+  %0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x?xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_concat_axis_1
+func @test_concat_axis_1(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> () {
+  // CHECK: "tosa.concat"(%arg0, %arg1) {axis = 1 : i64} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<2x3xf32>
+  %0 = "tosa.concat"(%arg0, %arg1) {axis = 1 : i64} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_concat_failure
+func @test_concat_failure(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> () {
+  // CHECK: "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
+  %0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_padding_no_const
+func @test_padding_no_const(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xi32>) -> () {
+  // CHECK: "tosa.pad"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
+  %0 = "tosa.pad"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_padding_dynamic_input
+func @test_padding_dynamic_input(%arg0 : tensor<1x?xf32>) -> () {
+  %0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  // CHECK: "tosa.pad"(%arg0, %cst) : (tensor<1x?xf32>, tensor<2x2xi32>) -> tensor<4x?xf32>
+  %1 = "tosa.pad"(%arg0, %0) : (tensor<1x?xf32>, tensor<2x2xi32>) -> (tensor<?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_padding_simple
+func @test_padding_simple(%arg0 : tensor<1x2xf32>) -> () {
+  %0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  // CHECK: "tosa.pad"(%arg0, %cst) : (tensor<1x2xf32>, tensor<2x2xi32>) -> tensor<4x9xf32>
+  %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<?x?xf32>)
+  return
+}
+
+// -----
+
// CHECK: "tosa.scatter"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor<3x6x5xi32>) -> tensor<3x?x5xi32> + %0 = "tosa.scatter"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor<3x6x5xi32>) -> (tensor) + return +} + +// ----- + +// CHECK-LABEL: @scatter_minimum_static +func @scatter_minimum_static(%arg0 : tensor, %arg1 : tensor<3x?xi32>, %arg2 : tensor) { + // CHECK: "tosa.scatter"(%arg0, %arg1, %arg2) : (tensor, tensor<3x?xi32>, tensor) -> tensor<3x4x5xi32> + %0 = "tosa.scatter"(%arg0, %arg1, %arg2) : (tensor, tensor<3x?xi32>, tensor) -> (tensor) + return +} diff --git a/mlir/test/Dialect/Tosa/tosa_infer_shapes.mlir b/mlir/test/Dialect/Tosa/tosa_infer_shapes.mlir deleted file mode 100644 --- a/mlir/test/Dialect/Tosa/tosa_infer_shapes.mlir +++ /dev/null @@ -1,278 +0,0 @@ -// RUN: mlir-opt --split-input-file --tosa-infer-shapes %s | FileCheck %s - -// CHECK-LABEL: @test_return -func @test_return(%arg0 : tensor<4xf32>) -> tensor<*xf32> { - // CHECK: [[LOG:%.+]] = "tosa.log"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: tensor.cast [[LOG]] : tensor<4xf32> to tensor<*xf32> - %0 = "tosa.log"(%arg0) : (tensor<4xf32>) -> tensor<*xf32> - return %0 : tensor<*xf32> -} - -// ----- - -// CHECK-LABEL: @test_multiple -func @test_multiple(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>, %arg2 : tensor) -> tensor<*xf32> { - // CHECK: [[ADD:%.+]] = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> - %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> - - // CHECK: [[LOG:%.+]] = "tosa.log"(%0) : (tensor<4xf32>) -> tensor<4xf32> - %1 = "tosa.log"(%0) : (tensor<*xf32>) -> tensor<*xf32> - - // CHECK: [[SUB:%.+]] = "tosa.sub"(%0, %arg2) : (tensor<4xf32>, tensor) -> tensor<4xf32> - %2 = "tosa.sub"(%0, %arg2) : (tensor<*xf32>, tensor) -> tensor<*xf32> - return %0 : tensor<*xf32> -} - -// ----- - -// CHECK-LABEL: @test_unary_f32 -func @test_unary_f32(%arg0 : tensor<4xf32>) -> () { - // CHECK: "tosa.abs"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> - %0 = "tosa.abs"(%arg0) : (tensor<4xf32>) -> tensor<*xf32> - - // CHECK: "tosa.ceil"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> - %1 = "tosa.ceil"(%arg0) : (tensor<4xf32>) -> tensor<*xf32> - - // CHECK: "tosa.clamp"(%arg0) {{.+}} : (tensor<4xf32>) -> tensor<4xf32> - %2 = "tosa.clamp"(%arg0) { max_int = 10 : i64, min_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 10.0 : f32 } : (tensor<4xf32>) -> tensor<*xf32> - - // CHECK: "tosa.exp"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> - %3 = "tosa.exp"(%arg0) : (tensor<4xf32>) -> tensor<*xf32> - - // CHECK: "tosa.floor"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> - %4 = "tosa.floor"(%arg0) : (tensor<4xf32>) -> tensor<*xf32> - - // CHECK: "tosa.log"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> - %5 = "tosa.log"(%arg0) : (tensor<4xf32>) -> tensor<*xf32> - - // CHECK: "tosa.negate"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> - %6 = "tosa.negate"(%arg0) : (tensor<4xf32>) -> tensor<*xf32> - - // CHECK: "tosa.reciprocal"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> - %7 = "tosa.reciprocal"(%arg0) : (tensor<4xf32>) -> tensor<*xf32> - - // CHECK: "tosa.reluN"(%arg0) {{.+}} : (tensor<4xf32>) -> tensor<4xf32> - %8 = "tosa.reluN"(%arg0) { max_int = 10 : i64, min_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 10.0 : f32 } : (tensor<4xf32>) -> tensor<*xf32> - - // CHECK: "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor<4xf32>) -> tensor<4xf32> - %9 = "tosa.reverse"(%arg0) { axis = 0 : i64 } : (tensor<4xf32>) -> tensor - - // CHECK: "tosa.rsqrt"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> - %10 = 
"tosa.rsqrt"(%arg0) : (tensor<4xf32>) -> tensor<*xf32> - - // CHECK: "tosa.tanh"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> - %11 = "tosa.tanh"(%arg0) : (tensor<4xf32>) -> tensor<*xf32> - - // CHECK: "tosa.sigmoid"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> - %12 = "tosa.sigmoid"(%arg0) : (tensor<4xf32>) -> tensor<*xf32> - return -} - -// ----- - -// CHECK-LABEL: @test_unary_i32 -func @test_unary_i32(%arg0 : tensor<4xi32>) -> () { - // CHECK: "tosa.abs"(%arg0) : (tensor<4xi32>) -> tensor<4xi32> - %0 = "tosa.abs"(%arg0) : (tensor<4xi32>) -> tensor<*xi32> - - // CHECK: "tosa.bitwise_not"(%arg0) : (tensor<4xi32>) -> tensor<4xi32> - %1 = "tosa.bitwise_not"(%arg0) : (tensor<4xi32>) -> tensor<*xi32> - - // CHECK: "tosa.clamp"(%arg0) {{.+}} : (tensor<4xi32>) -> tensor<4xi32> - %2 = "tosa.clamp"(%arg0) { max_int = 10 : i64, min_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 10.0 : f32 } : (tensor<4xi32>) -> tensor<*xi32> - - // CHECK: "tosa.clz"(%arg0) : (tensor<4xi32>) -> tensor<4xi32> - %3 = "tosa.clz"(%arg0) : (tensor<4xi32>) -> tensor<*xi32> - - // CHECK: "tosa.negate"(%arg0) : (tensor<4xi32>) -> tensor<4xi32> - %4 = "tosa.negate"(%arg0) : (tensor<4xi32>) -> tensor<*xi32> - - // CHECK: "tosa.reluN"(%arg0) {{.+}} : (tensor<4xi32>) -> tensor<4xi32> - %5 = "tosa.reluN"(%arg0) { max_int = 10 : i64, min_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 10.0 : f32 } : (tensor<4xi32>) -> tensor<*xi32> - - // CHECK: "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor<4xi32>) -> tensor<4xi32> - %6 = "tosa.reverse"(%arg0) { axis = 0 : i64 } : (tensor<4xi32>) -> tensor - return -} - -// ----- - -// CHECK-LABEL: @test_unary_i1 -func @test_unary_i1(%arg0 : tensor<4xi1>) -> () { - // CHECK: "tosa.logical_not"(%arg0) : (tensor<4xi1>) -> tensor<4xi1> - %0 = "tosa.logical_not"(%arg0) : (tensor<4xi1>) -> tensor<*xi1> - return -} - -// ----- - -// CHECK-LABEL: @test_binary_scalar_f32 -func @test_binary_scalar_f32(%arg0 : tensor<4xf32>, %arg1 : tensor) -> () { - // CHECK: "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xf32> - %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xf32> - - // CHECK: "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xf32> - %1 = "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xf32> - - // CHECK: "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xf32> - %2 = "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xf32> - - // CHECK: "tosa.mul"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xf32>, tensor) -> tensor<4xf32> - %3 = "tosa.mul"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xf32>, tensor) -> tensor<*xf32> - - // CHECK: "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xf32> - %4 = "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xf32> - - // CHECK: "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xf32> - %5 = "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xf32> - - // CHECK: "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xi1> - %6 = "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xi1> - - // CHECK: "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xi1> - %7 = "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xi1> - - // CHECK: "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xi1> - %8 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xi1> - - return -} - -// ----- - -// CHECK-LABEL: @test_binary_broadcast_f32 -func 
-// CHECK-LABEL: @test_unary_i1
-func @test_unary_i1(%arg0 : tensor<4xi1>) -> () {
-  // CHECK: "tosa.logical_not"(%arg0) : (tensor<4xi1>) -> tensor<4xi1>
-  %0 = "tosa.logical_not"(%arg0) : (tensor<4xi1>) -> tensor<*xi1>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @test_binary_scalar_f32
-func @test_binary_scalar_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<f32>) -> () {
-  // CHECK: "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
-  %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
-
-  // CHECK: "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
-  %1 = "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
-
-  // CHECK: "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
-  %2 = "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
-
-  // CHECK: "tosa.mul"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
-  %3 = "tosa.mul"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
-
-  // CHECK: "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
-  %4 = "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
-
-  // CHECK: "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
-  %5 = "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
-
-  // CHECK: "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xi1>
-  %6 = "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xi1>
-
-  // CHECK: "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xi1>
-  %7 = "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xi1>
-
-  // CHECK: "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xi1>
-  %8 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xi1>
-
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @test_binary_broadcast_f32
-func @test_binary_broadcast_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>) -> () {
-  // CHECK: "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
-  %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
-
-  // CHECK: "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
-  %1 = "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
-
-  // CHECK: "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
-  %2 = "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
-
-  // CHECK: "tosa.mul"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
-  %3 = "tosa.mul"(%arg0, %arg1) { shift = 0 : i32 } : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
-
-  // CHECK: "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
-  %4 = "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
-
-  // CHECK: "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
-  %5 = "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
-
-  // CHECK: "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1>
-  %6 = "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
-
-  // CHECK: "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1>
-  %7 = "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
-
-  // CHECK: "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1>
-  %8 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
-
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @test_binary_i32
-func @test_binary_i32(%arg0 : tensor<4xi32>, %arg1 : tensor<i32>) -> () {
-  // CHECK: "tosa.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
-  %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
-
-  // CHECK: "tosa.bitwise_and"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
-  %1 = "tosa.bitwise_and"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
-
-  // CHECK: "tosa.bitwise_or"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
-  %2 = "tosa.bitwise_or"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
-
-  // CHECK: "tosa.bitwise_xor"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
-  %3 = "tosa.bitwise_xor"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
-
-  // CHECK: "tosa.equal"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<4xi1>
-  %4 = "tosa.equal"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<*xi1>
-
-  // CHECK: "tosa.greater"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<4xi1>
-  %5 = "tosa.greater"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<*xi1>
-
-  // CHECK: "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<4xi1>
-  %6 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> 
tensor<*xi32> - - // CHECK: "tosa.maximum"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi32> - %9 = "tosa.maximum"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi32> - - // CHECK: "tosa.minimum"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi32> - %10 = "tosa.minimum"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi32> - - // CHECK: "tosa.mul"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor) -> tensor<4xi32> - %11 = "tosa.mul"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor) -> tensor<*xi32> - - // CHECK: "tosa.pow"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi32> - %12 = "tosa.pow"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi32> - - // CHECK: "tosa.sub"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi32> - %13 = "tosa.sub"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi32> - - return -} - -// ----- - -// CHECK-LABEL: @test_binary_i1 -func @test_binary_i1(%arg0 : tensor<4xi1>, %arg1 : tensor) -> () { - // CHECK "tosa.logical_and"(%arg0, %arg1) : (tensor<4xi1>, tensor) -> tensor<4xi1> - %0 = "tosa.logical_and"(%arg0, %arg1): (tensor<4xi1>, tensor) -> tensor<*xi1> - - // CHECK "tosa.logical_or"(%arg0, %arg1) : (tensor<4xi1>, tensor) -> tensor<4xi1> - %1 = "tosa.logical_or"(%arg0, %arg1): (tensor<4xi1>, tensor) -> tensor<*xi1> - - // CHECK "tosa.logical_xor"(%arg0, %arg1) : (tensor<4xi1>, tensor) -> tensor<*4i1> - %2 = "tosa.logical_xor"(%arg0, %arg1): (tensor<4xi1>, tensor) -> tensor<*xi1> - - return -} - -// ----- - -// CHECK-LABEL: @test_select_i32 -func @test_select_i32(%arg0 : tensor<4xi1>, %arg1 : tensor, %arg2 : tensor<4xi32>) -> () { - // CHECK: "tosa.select"(%arg0, %arg1, %arg2) : (tensor<4xi1>, tensor, tensor<4xi32>) -> tensor<4xi32> - %0 = "tosa.select"(%arg0, %arg1, %arg2): (tensor<4xi1>, tensor, tensor<4xi32>) -> tensor<*xi32> - - return -} - -// ----- - -func @test_static_reshape(%arg0 : tensor<4x4xi32>) -> () { - // CHECK: "tosa.reshape"(%arg0) {new_shape = [16]} : (tensor<4x4xi32>) -> tensor<16xi32> - %0 = "tosa.reshape"(%arg0) {new_shape = [16]} : (tensor<4x4xi32>) -> tensor - - // CHECK: "tosa.reshape"(%arg0) {new_shape = [-1]} : (tensor<4x4xi32>) -> tensor<16xi32> - %1 = "tosa.reshape"(%arg0) {new_shape = [-1]} : (tensor<4x4xi32>) -> tensor - - // CHECK: "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor<4x4xi32>) -> tensor<2x8xi32> - %2 = "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor<4x4xi32>) -> tensor - - return -} -// ----- - -func @test_dynamic_reshape(%arg0 : tensor<4x?xi32>) -> () { - // CHECK: %0 = "tosa.reshape"(%arg0) {new_shape = [16]} : (tensor<4x?xi32>) -> tensor<16xi32> - %0 = "tosa.reshape"(%arg0) {new_shape = [16]} : (tensor<4x?xi32>) -> tensor - - // CHECK: %1 = "tosa.reshape"(%arg0) {new_shape = [-1]} : (tensor<4x?xi32>) -> tensor - %1 = "tosa.reshape"(%arg0) {new_shape = [-1]} : (tensor<4x?xi32>) -> tensor - - // CHECK: %2 = "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor<4x?xi32>) -> tensor<2x?xi32> - %2 = "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor<4x?xi32>) -> tensor - - return -} -