diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h --- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h +++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h @@ -72,6 +72,13 @@ return dynamicDims; } +/// Common code to create the reshape op where necessary to make the rank of two +/// values equal. input1 and input2 will be updated when the rank has +/// changed. The caller is expected to use these to rewrite the original +/// operator with the RESHAPE now in the graph. +LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc, + Value &input1, Value &input2); + } // namespace tosa } // namespace mlir diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/Pass/Pass.h" using namespace mlir; @@ -28,6 +29,7 @@ PatternRewriter &rewriter) const override { Value input = op.getInput(); Value weight = op.getWeight(); + ShapedType inputType = input.getType().cast(); ShapedType weightType = weight.getType().cast(); ShapedType resultType = op.getOutput().getType().cast(); @@ -77,7 +79,9 @@ if (zp == 0) return val; auto ety = val.getType().cast().getElementType(); - auto zpTy = RankedTensorType::get({}, ety); + std::vector shape(val.getType().cast().getRank(), + 1); + auto zpTy = RankedTensorType::get(shape, ety); auto zpAttr = DenseElementsAttr::get(zpTy, rewriter.getIntegerAttr(ety, zp)); auto zpVal = rewriter.create(op.getLoc(), zpTy, zpAttr); @@ -127,6 +131,11 @@ auto mulShapeType = RankedTensorType::get( mulShape, weight.getType().dyn_cast().getElementType()); + + if (EqualizeRanks(rewriter, op.getLoc(), input, weight).failed()) { + return failure(); + } + Value mulValue = rewriter .create(op.getLoc(), mulShapeType, input, weight, /*shift=*/0) @@ -137,14 +146,18 @@ auto outputShapeType = RankedTensorType::get( outputShape, input.getType().dyn_cast().getElementType()); - auto outputValue = rewriter.create( + Value outputValue = rewriter.create( op.getLoc(), outputShapeType, mulValue, rewriter.getDenseI64ArrayAttr(outputShape)); + Value bias = op.getBias(); + if (EqualizeRanks(rewriter, op.getLoc(), outputValue, bias).failed()) { + return failure(); + } + // Add in the bias. rewriter - .replaceOpWithNewOp(op, outputShapeType, outputValue, - op.getBias()) + .replaceOpWithNewOp(op, outputShapeType, outputValue, bias) .getResult(); return success(); } diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" #include "mlir/Pass/Pass.h" @@ -365,10 +366,14 @@ Value resultPaddingVal = createOpAndInfer( rewriter, loc, resultPaddingAttr.getType(), resultPaddingAttr); - auto resultPad = createOpAndInfer( + Value resultPad = createOpAndInfer( rewriter, loc, UnrankedTensorType::get(resultETy), slice, resultPaddingVal); + if (EqualizeRanks(rewriter, op.getLoc(), resultPad, bias).failed()) { + return failure(); + } + rewriter.replaceOpWithNewOp(op, op.getType(), resultPad, bias); return success(); } diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp --- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -28,60 +29,17 @@ using namespace mlir; using namespace mlir::tosa; -/// There are two potential ways implementing broadcast: -/// a. https://www.tensorflow.org/xla/broadcasting#formal_definition -/// b. https://numpy.org/doc/stable/user/basics.broadcasting.html -/// This pass implements b (numpy style) now. - -/// In this pass, we insert RESHAPE operators to increase the rank of the -/// lower rank operand as a first step in the broadcasting process. The TOSA -/// operators that support broadcast require that the rank of the operands -/// are equal. - -// Examples: -// If lower=[c], higher=[a, b, c], [c] reshaped into [1, 1, c]. -// If lower=[b, c], higher=[a, b, c], [b, c] reshaped into [1, b, c]. -// If lower=[a], higher=[a, a], [a] reshaped into [1, a]. -// If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a]. -// If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1]. - -static LogicalResult -computeReshapeOutput(ArrayRef higherRankShape, - ArrayRef lowerRankShape, - SmallVectorImpl &reshapeOutputShape) { - // Initialize new shapes with [1] * higherRank. - int64_t higherRank = higherRankShape.size(); - int64_t lowerRank = lowerRankShape.size(); - - reshapeOutputShape.assign(higherRank, 1); - - int64_t higherRankDim; - int64_t lowerRankDim; - - for (int64_t i = higherRank - 1, j = lowerRank - 1; i >= 0 && j >= 0; - i--, j--) { - higherRankDim = higherRankShape[i]; - lowerRankDim = lowerRankShape[j]; - - if (lowerRankDim == 1 && higherRankDim > 1) - reshapeOutputShape[i] = 1; - else if ((lowerRankDim > 1 && higherRankDim == 1) || - (lowerRankDim == higherRankDim)) - reshapeOutputShape[i] = lowerRankDim; - else if (higherRankDim != lowerRankDim) - return failure(); - } - return success(); -} +namespace { /// Common code to create the reshape op where necessary to make the rank of the /// operations equal. input1 and input2 will be updated when the rank has /// changed. The caller is expected to use these to rewrite the original /// operator with the RESHAPE now in the graph. -static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter, - Location loc, - RankedTensorType outputType, - Value &input1, Value &input2) { +/// return failure when (1) no reshape needed, or (2) output_type is specified +/// and it has different rank +LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter, Location loc, + RankedTensorType outputType, Value &input1, + Value &input2) { auto input1Ty = input1.getType().dyn_cast(); auto input2Ty = input2.getType().dyn_cast(); @@ -96,54 +54,28 @@ return rewriter.notifyMatchFailure(loc, "cannot rewrite as its already correct"); - Value higherTensorValue, lowerTensorValue; - if (input1Rank > input2Rank) { - higherTensorValue = input1; - lowerTensorValue = input2; - } else { - higherTensorValue = input2; - lowerTensorValue = input1; + Value input1_copy = input1; + Value input2_copy = input2; + if (EqualizeRanks(rewriter, loc, input1_copy, input2_copy).failed()) { + return rewriter.notifyMatchFailure(loc, "failed to reshape inputs"); } - ArrayRef higherRankShape = - higherTensorValue.getType().cast().getShape(); - ArrayRef lowerRankShape = - lowerTensorValue.getType().cast().getShape(); - - SmallVector reshapeOutputShape; - - if (computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape) - .failed()) - return rewriter.notifyMatchFailure(loc, "fail to compute a reshape type"); - - auto reshapeInputType = lowerTensorValue.getType().cast(); - auto reshapeOutputType = RankedTensorType::get( - ArrayRef(reshapeOutputShape), reshapeInputType.getElementType()); - // Verify the rank agrees with the output type if the output type is ranked. if (outputType) { - if (outputType.getShape().size() != reshapeOutputShape.size() || - outputType.getShape().size() != higherRankShape.size()) + if (outputType.getRank() != + input1_copy.getType().cast().getRank() || + outputType.getRank() != + input2_copy.getType().cast().getRank()) return rewriter.notifyMatchFailure( loc, "the reshaped type doesn't agrees with the ranked output type"); } - auto reshapeLower = rewriter.create( - loc, reshapeOutputType, lowerTensorValue, - rewriter.getDenseI64ArrayAttr(reshapeOutputShape)); - - if (input1Rank > input2Rank) { - input1 = higherTensorValue; - input2 = reshapeLower.getResult(); - } else { - input1 = reshapeLower.getResult(); - input2 = higherTensorValue; - } + input1 = input1_copy; + input2 = input2_copy; return success(); } -namespace { template struct ConvertTosaOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -268,8 +200,10 @@ int32_t result1Rank = input1.getType().cast().getRank(); int32_t result2Rank = input2.getType().cast().getRank(); int32_t result3Rank = input3.getType().cast().getRank(); + int32_t outputRank = outputType.getRank(); - if ((result1Rank != result2Rank) || (result2Rank != result3Rank)) + if ((result1Rank != result2Rank) || (result2Rank != result3Rank) || + (result1Rank != outputRank)) return rewriter.notifyMatchFailure( tosaOp, "not all ranks are aligned with each other"); diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp --- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp +++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" using namespace mlir; using namespace mlir::tosa; @@ -60,3 +61,96 @@ APInt intMax = APInt::getSignedMaxValue(bitwidth); return value >= intMin.getSExtValue() && value <= intMax.getSExtValue(); } + +namespace { +// Given two tensors of high and low ranks, derive the output shape +// to reshape the lower rank to. +// Examples: +// If lower=[c], higher=[a, b, c], [c] reshaped into [1, 1, c]. +// If lower=[b, c], higher=[a, b, c], [b, c] reshaped into [1, b, c]. +// If lower=[a], higher=[a, a], [a] reshaped into [1, a]. +// If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a]. +// If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1]. +LogicalResult +computeReshapeOutput(ArrayRef higherRankShape, + ArrayRef lowerRankShape, + SmallVectorImpl &reshapeOutputShape) { + // Initialize new shapes with [1] * higherRank. + int64_t higherRank = higherRankShape.size(); + int64_t lowerRank = lowerRankShape.size(); + + reshapeOutputShape.assign(higherRank, 1); + + int64_t higherRankDim; + int64_t lowerRankDim; + + for (int64_t i = higherRank - 1, j = lowerRank - 1; i >= 0 && j >= 0; + i--, j--) { + higherRankDim = higherRankShape[i]; + lowerRankDim = lowerRankShape[j]; + + if (lowerRankDim == 1 && higherRankDim > 1) + reshapeOutputShape[i] = 1; + else if ((lowerRankDim > 1 && higherRankDim == 1) || + (lowerRankDim == higherRankDim)) + reshapeOutputShape[i] = lowerRankDim; + else if (higherRankDim != lowerRankDim) + return failure(); + } + return success(); +} +} // namespace + +LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc, + Value &input1, Value &input2) { + auto input1Ty = input1.getType().dyn_cast(); + auto input2Ty = input2.getType().dyn_cast(); + + if (!input1Ty || !input2Ty) { + return failure(); + } + + int64_t input1Rank = input1Ty.getRank(); + int64_t input2Rank = input2Ty.getRank(); + + if (input1Rank == input2Rank) + return success(); + + Value higherTensorValue, lowerTensorValue; + if (input1Rank > input2Rank) { + higherTensorValue = input1; + lowerTensorValue = input2; + } else { + higherTensorValue = input2; + lowerTensorValue = input1; + } + + ArrayRef higherRankShape = + higherTensorValue.getType().cast().getShape(); + ArrayRef lowerRankShape = + lowerTensorValue.getType().cast().getShape(); + + SmallVector reshapeOutputShape; + + if (computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape) + .failed()) + return failure(); + + auto reshapeInputType = lowerTensorValue.getType().cast(); + auto reshapeOutputType = RankedTensorType::get( + ArrayRef(reshapeOutputShape), reshapeInputType.getElementType()); + + auto reshapeLower = rewriter.create( + loc, reshapeOutputType, lowerTensorValue, + rewriter.getDenseI64ArrayAttr(reshapeOutputShape)); + + if (input1Rank > input2Rank) { + input1 = higherTensorValue; + input2 = reshapeLower.getResult(); + } else { + input1 = reshapeLower.getResult(); + input2 = higherTensorValue; + } + + return success(); +} diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir --- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir +++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir @@ -7,13 +7,17 @@ // CHECK-NOT: "tosa.depthwise_conv2d" // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) <{new_shape = array} // CHECK-SAME: -> tensor<4x10x10x2x1xf32> - // CHECK: %[[VAR2:.*]] = "tosa.mul"(%[[VAR0]], %arg1) + // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) <{new_shape = array} + // CHECK-SAME: -> tensor<1x1x1x2x3xf32> + // CHECK: %[[VAR2:.*]] = "tosa.mul"(%[[VAR0]], %[[VAR1]]) // CHECK-SAME: -> tensor<4x10x10x2x3xf32> // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) <{new_shape = array} // CHECK-SAME: -> tensor<4x10x10x6xf32> - // CHECK: %[[VAR4:.*]] = "tosa.add"(%[[VAR3]], %arg2) + // CHECK: %[[VAR4:.*]] = "tosa.reshape"(%arg2) <{new_shape = array} + // CHECK-SAME: -> tensor<1x1x1x6xf32> + // CHECK: %[[VAR5:.*]] = "tosa.add"(%[[VAR3]], %[[VAR4]]) // CHECK-SAME: -> tensor<4x10x10x6xf32> - // CHECK: return %[[VAR4]] + // CHECK: return %[[VAR5]] %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = array, stride = array, dilation = array} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32> return %0 : tensor<4x10x10x6xf32> } @@ -22,16 +26,18 @@ // CHECK-LABEL: @depthwise_conv2d_as_mul_q func.func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<1x1x2x3xi8>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6xi32> { - // CHECK: %[[iZp:.+]] = "tosa.const"() <{value = dense<7> : tensor} - // CHECK: %[[wZp:.+]] = "tosa.const"() <{value = dense<11> : tensor} + // CHECK: %[[iZp:.+]] = "tosa.const"() <{value = dense<7> : tensor<1x1x1x1x1xi32>} + // CHECK: %[[wZp:.+]] = "tosa.const"() <{value = dense<11> : tensor<1x1x1x1xi32>} // CHECK: %[[rIn:.+]] = "tosa.reshape"(%arg0) <{new_shape = array} // CHECK: %[[cIn:.+]] = "tosa.cast"(%[[rIn]]) : (tensor<4x10x10x2x1xi8>) -> tensor<4x10x10x2x1xi32> // CHECK: %[[cWe:.+]] = "tosa.cast"(%arg1) : (tensor<1x1x2x3xi8>) -> tensor<1x1x2x3xi32> // CHECK: %[[sIn:.+]] = "tosa.sub"(%[[cIn]], %[[iZp]]) // CHECK: %[[sWe:.+]] = "tosa.sub"(%[[cWe]], %[[wZp]]) - // CHECK: %[[mul:.+]] = "tosa.mul"(%[[sIn]], %[[sWe]]) <{shift = 0 : i32} + // CHECK: %[[resWe:.+]] = "tosa.reshape"(%[[sWe]]) <{new_shape = array} + // CHECK: %[[mul:.+]] = "tosa.mul"(%[[sIn]], %[[resWe]]) <{shift = 0 : i32} // CHECK: %[[reO:.+]] = "tosa.reshape"(%[[mul]]) <{new_shape = array} - // CHECK: %[[add:.+]] = "tosa.add"(%[[reO]], %arg2) + // CHECK: %[[reArg2:.+]] = "tosa.reshape"(%arg2) <{new_shape = array} + // CHECK: %[[add:.+]] = "tosa.add"(%[[reO]], %[[reArg2]]) %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = array, stride = array, dilation = array, quantization_info = #tosa.conv_quant} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32> return %0 : tensor<4x10x10x6xi32> } @@ -44,9 +50,11 @@ // CHECK: %[[zero:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor} // CHECK: %[[reIn:.+]] = "tosa.reshape"(%arg0) <{new_shape = array} // CHECK: %[[padded:.+]] = "tosa.pad"(%[[reIn]], %[[pad]], %[[zero]]) : (tensor<4x10x10x2x1xf32>, tensor<5x2xi64>, tensor) -> tensor<4x12x12x2x1xf32> - // CHECK: %[[mul:.+]] = "tosa.mul"(%3, %arg1) <{shift = 0 : i32} + // CHECK: %[[reArg1:.+]] = "tosa.reshape"(%arg1) <{new_shape = array} + // CHECK: %[[mul:.+]] = "tosa.mul"(%3, %[[reArg1]]) <{shift = 0 : i32} // CHECK: %[[reOut:.+]] = "tosa.reshape"(%[[mul]]) <{new_shape = array} - // CHECK: %[[add:.+]] = "tosa.add"(%[[reOut]], %arg2) + // CHECK: %[[reArg2:.+]] = "tosa.reshape"(%arg2) <{new_shape = array} + // CHECK: %[[add:.+]] = "tosa.add"(%[[reOut]], %[[reArg2]]) %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = array, stride = array, dilation = array} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x12x12x6xf32> return %0 : tensor<4x12x12x6xf32> } diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir --- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir +++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir @@ -28,7 +28,7 @@ func.func @transpose_conv2d_quantized_padded(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor<5x3x6x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x21x26x5xi32>) { // CHECK-DAG: %[[REV0:.+]] = "tosa.reverse"(%0) <{axis = 2 : i64} // CHECK-DAG: %[[REV1:.+]] = "tosa.reverse"(%arg1) <{axis = 1 : i64} - // CHECK: "tosa.conv2d"(%arg0, %1, %arg2) + // CHECK: "tosa.conv2d"(%arg0, %1, %arg2) // CHECK-SAME: dilation = array, pad = array, // CHECK-SAME: quantization_info = #tosa.conv_quant, stride = array} %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) { @@ -65,7 +65,8 @@ // CHECK-DAG: %[[TRANS_OUT:.+]] = "tosa.transpose"(%[[RESHAPE_OUT_1]], %[[TRANS2]]) // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = "tosa.reshape"(%[[TRANS_OUT]]) <{new_shape = array} // CHECK-DAG: %[[SLICE:.+]] = "tosa.slice"(%[[RESHAPE_OUT_2]]) <{size = array, start = array} - // CHECK: %[[ADD:.+]] = "tosa.add"(%[[SLICE]], %arg2) + // CHECK-DAG: %[[RESHAPE_ARG2:.+]] = "tosa.reshape"(%arg2) <{new_shape = array} + // CHECK: %[[ADD:.+]] = "tosa.add"(%[[SLICE]], %[[RESHAPE_ARG2]]) %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array, out_shape = array, stride = array} : (tensor<2x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>) -> tensor<2x35x47x5xf32> %1 = tensor.cast %0 : tensor<2x35x47x5xf32> to tensor<2x?x?x5xf32> return %1 : tensor<2x?x?x5xf32> @@ -97,8 +98,9 @@ // CHECK-DAG: %[[TRANS_OUT:.+]] = "tosa.transpose"(%[[RESHAPE_OUT_1]], %[[TRANS2]]) // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = "tosa.reshape"(%[[TRANS_OUT]]) <{new_shape = array} // CHECK-DAG: %[[SLICE:.+]] = "tosa.slice"(%[[RESHAPE_OUT_2]]) <{size = array, start = array} - // CHECK: %[[ADD:.+]] = "tosa.add"(%[[SLICE]], %arg2) - %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) <{out_pad = array, quantization_info = #tosa.conv_quant, out_shape = array, stride = array}> : (tensor<2x17x15x3xi8>, tensor<5x3x5x3xi8>, tensor<5xi32>) -> tensor<2x35x47x5xi32> + // CHECK-DAG: %[[RESHAPE_ARG2:.+]] = "tosa.reshape"(%arg2) <{new_shape = array} + // CHECK: %[[ADD:.+]] = "tosa.add"(%[[SLICE]], %[[RESHAPE_ARG2]]) + %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array, quantization_info = #tosa.conv_quant, out_shape = array, stride = array} : (tensor<2x17x15x3xi8>, tensor<5x3x5x3xi8>, tensor<5xi32>) -> tensor<2x35x47x5xi32> return %0 : tensor<2x35x47x5xi32> } @@ -106,14 +108,14 @@ // CHECK-LABEL: @transpose_conv2d_strided_overpad func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 : tensor<1x2x1x1xi8>, %arg2 : tensor<1xi32>) -> (tensor<1x19x2x1xi32>) { - // CHECK: %[[WEIGHT_PAD:.+]] = "tosa.const"() + // CHECK: %[[WEIGHT_PAD:.+]] = "tosa.const"() // CHECK-SAME{literal}: value = dense<[[0, 0], [0, 0], [0, 1], [0, 0]]> : tensor<4x2xi32> // CHECK: %[[WEIGHT_PERMS:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>} - // CHECK: %[[INPUT_PAD:.+]] = "tosa.const"() + // CHECK: %[[INPUT_PAD:.+]] = "tosa.const"() // CHECK-SAME{literal}: value = dense<[[0, 0], [1, 1], [0, 0], [0, 0]]> : tensor<4x2xi32>} // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<2xi32>} // CHECK: %[[RESULT_PERMS:.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>} - // CHECK: %[[RESULT_PAD:.+]] = "tosa.const"() + // CHECK: %[[RESULT_PAD:.+]] = "tosa.const"() // CHECK-SAME{literal}: value = dense<[[0, 0], [2, 0], [0, 0], [0, 0]]> : tensor<4x2xi32>} // CHECK: %[[PAD_WEIGHT:.+]] = "tosa.pad"(%arg1, %[[WEIGHT_PAD]]) <{quantization_info = #tosa.pad_quant} // CHECK: %[[RESHAPE_WEIGHT_0:.+]] = "tosa.reshape"(%[[PAD_WEIGHT]]) <{new_shape = array} @@ -121,13 +123,14 @@ // CHECK: %[[RESHAPE_WEIGHT_1:.+]] = "tosa.reshape"(%[[TRANSPOSE_WEIGHT]]) <{new_shape = array} // CHECK: %[[REVERSE:.+]] = "tosa.reverse"(%[[RESHAPE_WEIGHT_1]]) <{axis = 1 : i64} // CHECK: %[[PAD_INPUT:.+]] = "tosa.pad"(%arg0, %[[INPUT_PAD]]) <{quantization_info = #tosa.pad_quant} - // CHECK: %[[CONV:.+]] = "tosa.conv2d"(%[[PAD_INPUT]], %[[REVERSE]], %[[ZERO]]) + // CHECK: %[[CONV:.+]] = "tosa.conv2d"(%[[PAD_INPUT]], %[[REVERSE]], %[[ZERO]]) // CHECK-SAME{literal}: dilation = [1, 1], pad = [0, 0, 0, 0], quantization_info = #tosa.conv_quant, stride = [1, 1]} // CHECK: %[[RESHAPE_RESULT_0:.+]] = "tosa.reshape"(%[[CONV]]) <{new_shape = array} // CHECK: %[[TRANSPOSE_RESULT:.+]] = "tosa.transpose"(%[[RESHAPE_RESULT_0]], %[[RESULT_PERMS]]) // CHECK: %[[RESHAPE_RESULT_1:.+]] = "tosa.reshape"(%[[TRANSPOSE_RESULT]]) <{new_shape = array} // CHECK: %[[PAD_RESULT:.+]] = "tosa.pad"(%[[RESHAPE_RESULT_1]], %[[RESULT_PAD]]) - // CHECK: %[[ADD:.+]] = "tosa.add"(%[[PAD_RESULT]], %arg2) + // CHECK: %[[RESHAPE_ARG2:.+]] = "tosa.reshape"(%arg2) <{new_shape = array} + // CHECK: %[[ADD:.+]] = "tosa.add"(%[[PAD_RESULT]], %[[RESHAPE_ARG2]]) %2 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) { out_pad = array, out_shape = array,