diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -19,6 +19,7 @@ namespace mlir {
 namespace tosa {
 
+std::unique_ptr<Pass> createTosaDecomposeTransposeConvPass();
 std::unique_ptr<Pass> createTosaInferShapesPass();
 std::unique_ptr<Pass> createTosaMakeBroadcastablePass();
 std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -15,6 +15,21 @@
 
 include "mlir/Pass/PassBase.td"
 
+def TosaDecomposeTransposeConv : FunctionPass<"tosa-decompose-transpose-conv"> {
+  let summary = "Decompose transpose convolutions into standard convolutions.";
+  let description = [{
+    Pass that uses shape manipulation and convolution operations to transform
+    a transpose convolution into a regular convolution.
+  }];
+
+  let constructor = "createTosaDecomposeTransposeConvPass()";
+  let dependentDialects = [
+    "StandardOpsDialect",
+    "tensor::TensorDialect",
+    "tosa::TosaDialect",
+  ];
+}
+
 def TosaInferShapes : FunctionPass<"tosa-infer-shapes"> {
   let summary = "Propagate shapes across TOSA operations";
   let description = [{
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1384,77 +1384,6 @@
   }
 };
 
-class TransposeConvConverter
-    : public OpConversionPattern<tosa::TransposeConv2DOp> {
-public:
-  using OpConversionPattern<tosa::TransposeConv2DOp>::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(tosa::TransposeConv2DOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    Location loc = op->getLoc();
-    Value input = op->getOperand(0);
-    Value weight = op->getOperand(1);
-    Value bias = op->getOperand(2);
-
-    ShapedType inputTy = input.getType().cast<ShapedType>();
-    ShapedType weightTy = weight.getType().cast<ShapedType>();
-    ShapedType biasTy = bias.getType().cast<ShapedType>();
-    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
-
-    llvm::SmallVector<int64_t> pad;
-    llvm::SmallVector<int64_t> stride;
-    llvm::SmallVector<int64_t> dilation;
-
-    getValuesFromIntArrayAttribute(op.out_pad().cast<ArrayAttr>(), pad);
-    getValuesFromIntArrayAttribute(op.stride().cast<ArrayAttr>(), stride);
-    getValuesFromIntArrayAttribute(op.dilation().cast<ArrayAttr>(), dilation);
-
-    // If striding is all 1 we can modify padding and reverse the kernel along
-    // the x/y direction to make it a regular convolution. This is much simpler
-    // then handling striding....
-    if (llvm::all_of(stride, [](int64_t v) { return v == 1; })) {
-      if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
-          !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
-        return failure();
-
-      int64_t kernelHeight = (weightTy.getDimSize(1) - 1) * dilation[0] + 1;
-      int64_t kernelWidth = (weightTy.getDimSize(2) - 1) * dilation[1] + 1;
-      int64_t requiredInputHeight = resultTy.getDimSize(1) + kernelHeight - 1;
-      int64_t requiredInputWidth = resultTy.getDimSize(2) + kernelWidth - 1;
-
-      llvm::SmallVector<int64_t> convPad(4, 0);
-      convPad[0] = kernelHeight - 1 - pad[0];
-      convPad[2] = kernelWidth - 1 - pad[1];
-      convPad[1] = requiredInputHeight - convPad[0] - inputTy.getDimSize(1);
-      convPad[3] = requiredInputWidth - convPad[2] - inputTy.getDimSize(2);
-
-      auto reverse1 = rewriter.create<tosa::ReverseOp>(
-          loc, weightTy, weight, rewriter.getI64IntegerAttr(1));
-      auto reverse2 = rewriter.create<tosa::ReverseOp>(
-          loc, weightTy, reverse1, rewriter.getI64IntegerAttr(2));
-
-      Value conv2d;
-      if (op.quantization_info().hasValue()) {
-        conv2d = rewriter.create<tosa::Conv2DOp>(
-            loc, resultTy, input, reverse2, bias,
-            rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
-            rewriter.getI64ArrayAttr(dilation),
-            op.quantization_info().getValue());
-      } else {
-        conv2d = rewriter.create<tosa::Conv2DOp>(
-            loc, resultTy, input, reverse2, bias,
-            rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
-            rewriter.getI64ArrayAttr(dilation));
-      }
-
-      rewriter.replaceOp(op, conv2d);
-      return success();
-    }
-
-    return failure();
-  }
-};
-
 class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
 public:
   using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
@@ -3188,7 +3117,6 @@
       ConcatConverter,
       ConvConverter,
       DepthwiseConvConverter,
-      TransposeConvConverter,
       GatherConverter,
       PadConverter,
       ReshapeConverterCollapse,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -50,6 +50,7 @@
     target.addLegalOp();
     target.addLegalOp();
    target.addLegalOp();
+    target.addLegalOp();
 
     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRTosaTransforms
+  TosaDecomposeTransposeConv.cpp
   TosaInferShapes.cpp
   TosaMakeBroadcastable.cpp
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
new file
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -0,0 +1,390 @@
+//===- TosaDecomposeTransposeConv.cpp -------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Decompose tosa.transpose_conv2d into equivalent standard TOSA operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+template <typename T>
+static void getValuesFromIntArrayAttribute(ArrayAttr attr,
+                                           SmallVector<T> &arrayValues) {
+  for (Attribute val : attr.getValue()) {
+    arrayValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
+  }
+}
+
+template <typename TosaOp, typename... Args>
+TosaOp CreateOpAndInfer(PatternRewriter &rewriter, Location loc, Type result_ty,
+                        Args &&...args) {
+  auto op = rewriter.create<TosaOp>(loc, result_ty, args...);
+
+  InferShapedTypeOpInterface shapeInterface =
+      dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
+  if (!shapeInterface)
+    return op;
+
+  SmallVector<ShapedTypeComponents> returnedShapes;
+  if (shapeInterface
+          .inferReturnTypeComponents(op.getContext(), op.getLoc(),
+                                     op->getOperands(), op->getAttrDictionary(),
+                                     op->getRegions(), returnedShapes)
+          .failed())
+    return op;
+
+  // We need to use the element type of the existing result type to generate
+  // the new result shaped type. This is because rescale can include a cast to
+  // different bit-width types and does not have a TypeAttr to define the
+  // target type.
+  auto result = op->getResult(0);
+  auto predictedShape = returnedShapes[0];
+  auto currentKnowledge =
+      mlir::tosa::ValueKnowledge::getKnowledgeFromType(result_ty);
+
+  // Compute the knowledge based on the inferred type.
+  auto inferredKnowledge =
+      mlir::tosa::ValueKnowledge::getPessimisticValueState();
+  inferredKnowledge.dtype = result_ty.cast<ShapedType>().getElementType();
+  inferredKnowledge.hasRank = predictedShape.hasRank();
+  if (predictedShape.hasRank()) {
+    for (auto dim : predictedShape.getDims()) {
+      inferredKnowledge.sizes.push_back(dim);
+    }
+  }
+
+  // Compute the new type based on the joined version.
+  auto newKnowledge =
+      mlir::tosa::ValueKnowledge::join(currentKnowledge, inferredKnowledge);
+  auto new_ty = newKnowledge.getType();
+  result.setType(new_ty);
+  return op;
+}
+
+class TransposeConvDilatedConverter
+    : public OpRewritePattern<tosa::TransposeConv2DOp> {
+public:
+  using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
+                                PatternRewriter &rewriter) const final {
+    Location loc = op->getLoc();
+    Value input = op->getOperand(0);
+    Value weight = op->getOperand(1);
+    Value bias = op->getOperand(2);
+
+    ShapedType inputTy = input.getType().cast<ShapedType>();
+    ShapedType weightTy = weight.getType().cast<ShapedType>();
+    ShapedType biasTy = bias.getType().cast<ShapedType>();
+    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
+
+    llvm::SmallVector<int64_t> pad;
+    llvm::SmallVector<int64_t> stride;
+    llvm::SmallVector<int64_t> dilation;
+
+    getValuesFromIntArrayAttribute(op.out_pad().cast<ArrayAttr>(), pad);
+    getValuesFromIntArrayAttribute(op.stride().cast<ArrayAttr>(), stride);
+    getValuesFromIntArrayAttribute(op.dilation().cast<ArrayAttr>(), dilation);
+
+    // If striding is all 1 we can modify padding and reverse the kernel along
+    // the x/y direction to make it a regular convolution. This is much simpler
+    // than handling striding.
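+    // Strided transpose convolutions are decomposed separately by
+    // TransposeConvStridedConverter below.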
+    if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
+      return failure();
+
+    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
+        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
+      return failure();
+
+    int64_t kernelHeight = (weightTy.getDimSize(1) - 1) * dilation[0] + 1;
+    int64_t kernelWidth = (weightTy.getDimSize(2) - 1) * dilation[1] + 1;
+    int64_t requiredInputHeight = resultTy.getDimSize(1) + kernelHeight - 1;
+    int64_t requiredInputWidth = resultTy.getDimSize(2) + kernelWidth - 1;
+
+    llvm::SmallVector<int64_t> convPad(4, 0);
+    convPad[0] = kernelHeight - 1 - pad[0];
+    convPad[2] = kernelWidth - 1 - pad[1];
+    convPad[1] = requiredInputHeight - convPad[0] - inputTy.getDimSize(1);
+    convPad[3] = requiredInputWidth - convPad[2] - inputTy.getDimSize(2);
+
+    auto reverse1 = rewriter.create<tosa::ReverseOp>(
+        loc, weightTy, weight, rewriter.getI64IntegerAttr(1));
+    auto reverse2 = rewriter.create<tosa::ReverseOp>(
+        loc, weightTy, reverse1, rewriter.getI64IntegerAttr(2));
+
+    Value conv2d;
+    if (op.quantization_info().hasValue()) {
+      conv2d = rewriter.create<tosa::Conv2DOp>(
+          loc, resultTy, input, reverse2, bias,
+          rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
+          rewriter.getI64ArrayAttr(dilation),
+          op.quantization_info().getValue());
+    } else {
+      conv2d = rewriter.create<tosa::Conv2DOp>(
+          loc, resultTy, input, reverse2, bias,
+          rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
+          rewriter.getI64ArrayAttr(dilation));
+    }
+
+    rewriter.replaceOp(op, conv2d);
+    return success();
+  }
+};
+
+class TransposeConvStridedConverter
+    : public OpRewritePattern<tosa::TransposeConv2DOp> {
+public:
+  using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
+                                PatternRewriter &rewriter) const final {
+    Location loc = op->getLoc();
+    Value input = op->getOperand(0);
+    Value weight = op->getOperand(1);
+    Value bias = op->getOperand(2);
+
+    ShapedType inputTy = input.getType().cast<ShapedType>();
+    ShapedType weightTy = weight.getType().cast<ShapedType>();
+    ShapedType biasTy = bias.getType().cast<ShapedType>();
+    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
+
+    Type inputETy = inputTy.getElementType();
+    Type weightETy = weightTy.getElementType();
+    Type biasETy = biasTy.getElementType();
+    Type resultETy = resultTy.getElementType();
+
+    llvm::SmallVector<int64_t> pad;
+    llvm::SmallVector<int64_t> stride;
+    llvm::SmallVector<int64_t> dilation;
+
+    getValuesFromIntArrayAttribute(op.out_pad().cast<ArrayAttr>(), pad);
+    getValuesFromIntArrayAttribute(op.stride().cast<ArrayAttr>(), stride);
+    getValuesFromIntArrayAttribute(op.dilation().cast<ArrayAttr>(), dilation);
+
+    // This pattern decomposes strided transpose convolutions; it does not
+    // handle a non-unit dilation.
+    if (llvm::any_of(dilation, [](int64_t v) { return v != 1; }))
+      return failure();
+
+    // If the strides are all 1, TransposeConvDilatedConverter handles the op.
+    if (llvm::all_of(stride, [](int64_t v) { return v == 1; }))
+      return failure();
+
+    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
+        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
+      return failure();
+
+    int64_t batch = inputTy.getDimSize(0);
+
+    int64_t outputChannels = weightTy.getDimSize(0);
+    int64_t weightHeight = weightTy.getDimSize(1);
+    int64_t weightWidth = weightTy.getDimSize(2);
+    int64_t inputChannels = weightTy.getDimSize(3);
+
+    // Pad the weight so that its height and width are multiples of the stride.
+    llvm::SmallVector<int32_t> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
+    weightPadding[3] =
+        weightHeight % stride[0] ? stride[0] - weightHeight % stride[0] : 0;
+    weightPadding[5] =
+        weightWidth % stride[1] ? stride[1] - weightWidth % stride[1] : 0;
+    DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
+        RankedTensorType::get({4, 2}, rewriter.getI32Type()), weightPadding);
+    Value weightPaddingVal = CreateOpAndInfer<tosa::ConstOp>(
+        rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);
+
+    if (op.quantization_info().hasValue()) {
+      auto quantInfo = op.quantization_info().getValue();
+      weight = CreateOpAndInfer<tosa::PadOp>(
+          rewriter, loc, UnrankedTensorType::get(weightETy), weight,
+          weightPaddingVal, nullptr,
+          PadOpQuantizationAttr::get(quantInfo.weight_zp(),
+                                     rewriter.getContext()));
+    } else {
+      weight = CreateOpAndInfer<tosa::PadOp>(
+          rewriter, loc, UnrankedTensorType::get(weightETy), weight,
+          weightPaddingVal);
+    }
+
+    weightTy = weight.getType().cast<ShapedType>();
+    weightHeight = weightTy.getDimSize(1);
+    weightWidth = weightTy.getDimSize(2);
+
+    // Split out the width / height by the stride dimensions.
+    llvm::SmallVector<int64_t> weightReshapeDims0 = {
+        outputChannels, weightHeight / stride[0],
+        stride[0],      weightWidth / stride[1],
+        stride[1],      inputChannels};
+    weight = CreateOpAndInfer<tosa::ReshapeOp>(
+        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
+        rewriter.getI64ArrayAttr(weightReshapeDims0));
+
+    // Transpose the factored-out stride to the output channels.
+    Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
+        loc, RankedTensorType::get({6}, rewriter.getI32Type()),
+        rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));
+
+    weight = CreateOpAndInfer<tosa::TransposeOp>(
+        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
+        transposeWeightVal);
+
+    // Collapse the strides and output channels into a single dimension.
+    llvm::SmallVector<int64_t> weightReshapeDims1 = {
+        outputChannels * stride[0] * stride[1], weightHeight / stride[0],
+        weightWidth / stride[1], inputChannels};
+    weight = CreateOpAndInfer<tosa::ReshapeOp>(
+        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
+        rewriter.getI64ArrayAttr(weightReshapeDims1));
+    ShapedType restridedWeightTy = weight.getType().cast<ShapedType>();
+
+    weight = CreateOpAndInfer<tosa::ReverseOp>(
+        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
+        rewriter.getI64IntegerAttr(1));
+    weight = CreateOpAndInfer<tosa::ReverseOp>(
+        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
+        rewriter.getI64IntegerAttr(2));
+
+    // We need to pad the input far enough that we can pull all values.
+    llvm::SmallVector<int32_t> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
+    inputPadding[2] += restridedWeightTy.getDimSize(1) - 1;
+    inputPadding[3] += restridedWeightTy.getDimSize(1) - 1;
+    inputPadding[4] += restridedWeightTy.getDimSize(2) - 1;
+    inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;
+
+    DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
+        RankedTensorType::get({4, 2}, rewriter.getI32Type()), inputPadding);
+
+    Value inputPaddingVal = CreateOpAndInfer<tosa::ConstOp>(
+        rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);
+
+    if (op.quantization_info().hasValue()) {
+      auto quantInfo = op.quantization_info().getValue();
+      input = CreateOpAndInfer<tosa::PadOp>(
+          rewriter, loc, UnrankedTensorType::get(inputETy), input,
+          inputPaddingVal, nullptr,
+          PadOpQuantizationAttr::get(quantInfo.input_zp(),
+                                     rewriter.getContext()));
+    } else {
+      input = CreateOpAndInfer<tosa::PadOp>(
+          rewriter, loc, UnrankedTensorType::get(inputETy), input,
+          inputPaddingVal);
+    }
+
+    // We use a zero bias as we need to broadcast the bias.
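+    // The convolution below runs with a zero bias because, at that point, the
+    // stride factors are still folded into the channel dimension of the
+    // result. The original bias is added by the tosa.add at the end of this
+    // pattern, once the result has been reshaped and sliced to its final
+    // shape.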
+    auto zeroBias = rewriter.create<tosa::ConstOp>(
+        loc,
+        RankedTensorType::get({outputChannels * stride[0] * stride[1]},
+                              biasETy),
+        DenseElementsAttr::get(
+            RankedTensorType::get({outputChannels * stride[0] * stride[1]},
+                                  biasETy),
+            rewriter.getZeroAttr(biasETy)));
+
+    // Perform the convolution using the zero bias.
+    Value conv2d;
+    if (op.quantization_info().hasValue()) {
+      conv2d = CreateOpAndInfer<tosa::Conv2DOp>(
+                   rewriter, loc, UnrankedTensorType::get(resultETy), input,
+                   weight, zeroBias,
+                   /*pad=*/rewriter.getI64ArrayAttr({0, 0, 0, 0}),
+                   /*stride=*/rewriter.getI64ArrayAttr({1, 1}),
+                   /*dilation=*/rewriter.getI64ArrayAttr({1, 1}),
+                   op.quantization_info().getValue())
+                   .getResult();
+    } else {
+      conv2d = CreateOpAndInfer<tosa::Conv2DOp>(
+                   rewriter, loc, UnrankedTensorType::get(resultETy), input,
+                   weight, zeroBias,
+                   /*pad=*/rewriter.getI64ArrayAttr({0, 0, 0, 0}),
+                   /*stride=*/rewriter.getI64ArrayAttr({1, 1}),
+                   /*dilation=*/rewriter.getI64ArrayAttr({1, 1}))
+                   .getResult();
+    }
+
+    // Factor the resulting width / height.
+    ShapedType convTy = conv2d.getType().cast<ShapedType>();
+    Type convETy = convTy.getElementType();
+
+    int64_t convHeight = convTy.getDimSize(1);
+    int64_t convWidth = convTy.getDimSize(2);
+
+    // Factor striding out of the convolution result.
+    llvm::SmallVector<int64_t> convReshapeDims0 = {
+        batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
+    conv2d = CreateOpAndInfer<tosa::ReshapeOp>(
+        rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
+        rewriter.getI64ArrayAttr(convReshapeDims0));
+
+    // Transpose the factored-out stride to the output channels.
+    Value transposeConvVal = rewriter.create<tosa::ConstOp>(
+        loc, RankedTensorType::get({6}, rewriter.getI32Type()),
+        rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));
+
+    conv2d = CreateOpAndInfer<tosa::TransposeOp>(
+        rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
+        transposeConvVal);
+
+    // Fuse striding behavior back into width / height.
+    llvm::SmallVector<int64_t> convReshapeDims1 = {
+        batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
+    conv2d = CreateOpAndInfer<tosa::ReshapeOp>(
+        rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
+        rewriter.getI64ArrayAttr(convReshapeDims1));
+
+    // Slice out the final result.
+    llvm::SmallVector<int64_t> sliceBegin = {0, 0, 0, 0};
+    llvm::SmallVector<int64_t> sliceSize(resultTy.getShape().begin(),
+                                         resultTy.getShape().end());
+    sliceBegin[1] = pad[0];
+    sliceBegin[2] = pad[1];
+
+    auto slice = CreateOpAndInfer<tosa::SliceOp>(
+                     rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
+                     rewriter.getI64ArrayAttr(sliceBegin),
+                     rewriter.getI64ArrayAttr(resultTy.getShape()))
+                     .getResult();
+
+    auto addBias =
+        CreateOpAndInfer<tosa::AddOp>(rewriter, loc, op.getType(), slice, bias);
+
+    rewriter.replaceOp(op, addBias.getResult());
+
+    return success();
+  }
+};
+
+/// Pass that decomposes tosa.transpose_conv2d into equivalent tosa.conv2d
+/// plus reshape, transpose, pad, and slice operations.
+struct TosaDecomposeTransposeConv
+    : public TosaDecomposeTransposeConvBase<TosaDecomposeTransposeConv> {
+public:
+  void runOnFunction() override {
+    auto func = getFunction();
+    RewritePatternSet patterns(func.getContext());
+    patterns
+        .insert<TransposeConvDilatedConverter, TransposeConvStridedConverter>(
+            func.getContext());
+    (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
+  }
+};
+} // end anonymous namespace
+
+std::unique_ptr<Pass> mlir::tosa::createTosaDecomposeTransposeConvPass() {
+  return std::make_unique<TosaDecomposeTransposeConv>();
+}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1719,27 +1719,6 @@
   return
 }
 
-// -----
-
-// CHECK-LABEL: @transpose_conv
-func @transpose_conv(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<4x3x3x2xf32>, %arg2 : tensor<4xf32>) -> () {
-  // CHECK: linalg.pad_tensor %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0]
-  // CHECK: linalg.conv_2d_nhwc_hwcf
-  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [1, 14, 14, 4], stride = [1, 1]} : (tensor<1x12x12x2xf32>, tensor<4x3x3x2xf32>, tensor<4xf32>) -> tensor<1x14x14x4xf32>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @transpose_conv_dilated
-func @transpose_conv_dilated(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<4x3x3x2xf32>, %arg2 : tensor<4xf32>) -> () {
-  // CHECK: [[PAD:%.+]] = linalg.pad_tensor %arg0 low[0, 4, 4, 0] high[0, 4, 4, 0]
-  // CHECK: linalg.conv_2d_nhwc_hwcf {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins([[PAD]], {{%.+}} : tensor<1x20x20x2xf32>, tensor<3x3x2x4xf32>)
-  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [2, 2], out_pad = [0, 0], out_shape = [1, 16, 16, 4], stride = [1, 1]} : (tensor<1x12x12x2xf32>, tensor<4x3x3x2xf32>, tensor<4xf32>) -> tensor<1x16x16x4xf32>
-  return
-}
-
-
 // -----
 
 // CHECK-LABEL: @resize_nearest
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
new file
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
@@ -0,0 +1,97 @@
+// RUN: mlir-opt --split-input-file --tosa-decompose-transpose-conv %s | FileCheck %s
+
+// CHECK-LABEL: @transpose_conv2d
+func @transpose_conv2d(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x?x?x5xf32> {
+  // CHECK: %[[REV1:.+]] = "tosa.reverse"(%arg1) {axis = 1 : i64}
+  // CHECK: %[[REV2:.+]] = "tosa.reverse"(%[[REV1]]) {axis = 2 : i64}
+  // CHECK: "tosa.conv2d"(%arg0, %[[REV2]], %arg2) {dilation = [1, 1], pad = [2, 2, 5, 5], stride = [1, 1]}
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x18x19x5xf32>
+  %1 = tensor.cast %0 : tensor<2x18x19x5xf32> to tensor<2x?x?x5xf32>
+  return %1 : tensor<2x?x?x5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_conv2d_quantized
+func @transpose_conv2d_quantized(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor<5x3x6x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x18x19x5xi32>) {
+  // CHECK: %[[REV1:.+]] = "tosa.reverse"(%arg1) {axis = 1 : i64}
+  // CHECK: %[[REV2:.+]] = "tosa.reverse"(%[[REV1]]) {axis = 2 : i64}
+  // CHECK: "tosa.conv2d"(%arg0, %[[REV2]], %arg2) {dilation = [1, 1], pad = [2, 2, 5, 5], quantization_info = {input_zp = -22 : i32, weight_zp = 42 : i32}, stride = [1, 1]}
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], quantization_info = {input_zp = -22 : i32, weight_zp = 42 : i32}, out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>) -> tensor<2x18x19x5xi32>
+  return %0 : tensor<2x18x19x5xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_conv2d_dilated
+func @transpose_conv2d_dilated(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x?x?x5xf32> {
+  // CHECK: %[[REV1:.+]] = "tosa.reverse"(%arg1) {axis = 1 : i64}
+  // CHECK: %[[REV2:.+]] = "tosa.reverse"(%[[REV1]]) {axis = 2 : i64}
+  // CHECK: "tosa.conv2d"(%arg0, %[[REV2]], %arg2) {dilation = [2, 3], pad = [4, 4, 15, 15], stride = [1, 1]}
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [2, 3], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x20x29x5xf32>
+  %1 = tensor.cast %0 : tensor<2x20x29x5xf32> to tensor<2x?x?x5xf32>
+  return %1 : tensor<2x?x?x5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_conv2d_strided
+func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<5x3x5x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x?x?x5xf32> {
+  // Manipulate the weight matrix to handle striding.
+  // CHECK-DAG: %[[PADV:.+]] = "tosa.const"() {value = dense<{{\[\[}}0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi32>}
+  // CHECK-DAG: %[[TRANSV:.+]] = "tosa.const"() {value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
+  // CHECK-DAG: %[[PADW:.+]] = "tosa.pad"(%arg1, %[[PADV]])
+  // CHECK-DAG: %[[RESW1:.+]] = "tosa.reshape"(%[[PADW]]) {new_shape = [5, 2, 2, 2, 3, 3]}
+  // CHECK-DAG: %[[TRANS:.+]] = "tosa.transpose"(%[[RESW1]], %[[TRANSV]])
+  // CHECK-DAG: %[[RESW2:.+]] = "tosa.reshape"(%[[TRANS]]) {new_shape = [30, 2, 2, 3]}
+  // CHECK-DAG: %[[REV1:.+]] = "tosa.reverse"(%[[RESW2]]) {axis = 1 : i64}
+  // CHECK-DAG: %[[NEWWEIGHT:.+]] = "tosa.reverse"(%[[REV1]]) {axis = 2 : i64}
+
+  // Pad out the input matrix to handle the transpose conv.
+  // CHECK-DAG: %[[PAD:.+]] = "tosa.const"() {value = dense<{{\[\[}}0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32>}
+  // CHECK-DAG: %[[TRANS2:.+]] = "tosa.const"() {value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
+  // CHECK-DAG: %[[NEWINPUT:.+]] = "tosa.pad"(%arg0, %[[PAD]])
+
+  // Manipulate the final shape.
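+  // With stride = [2, 3], the 5 output channels become 5 * 2 * 3 = 30
+  // convolution channels, and the 18x16 convolution result is expanded to
+  // 36x48 before being sliced down to the final 35x47 output.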
+  // CHECK-DAG: %[[BIAS:.+]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<30xf32>}
+  // CHECK-DAG: %[[CONV:.+]] = "tosa.conv2d"(%[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]]) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]}
+  // CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = "tosa.reshape"(%[[CONV]]) {new_shape = [2, 18, 16, 2, 3, 5]}
+  // CHECK-DAG: %[[TRANS_OUT:.+]] = "tosa.transpose"(%[[RESHAPE_OUT_1]], %[[TRANS2]])
+  // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = "tosa.reshape"(%[[TRANS_OUT]]) {new_shape = [2, 36, 48, 5]}
+  // CHECK-DAG: %[[SLICE:.+]] = "tosa.slice"(%[[RESHAPE_OUT_2]]) {size = [2, 35, 47, 5], start = [0, 0, 0, 0]}
+  // CHECK: %[[ADD:.+]] = "tosa.add"(%[[SLICE]], %arg2)
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [2, 3]} : (tensor<2x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>) -> tensor<2x35x47x5xf32>
+  %1 = tensor.cast %0 : tensor<2x35x47x5xf32> to tensor<2x?x?x5xf32>
+  return %1 : tensor<2x?x?x5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_conv2d_strided_quantized
+func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1: tensor<5x3x5x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x35x47x5xi32>) {
+  // Manipulate the weight matrix to handle striding.
+  // CHECK-DAG: %[[PADV:.+]] = "tosa.const"() {value = dense<{{\[\[}}0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi32>}
+  // CHECK-DAG: %[[TRANSV:.+]] = "tosa.const"() {value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
+  // CHECK-DAG: %[[PADW:.+]] = "tosa.pad"(%arg1, %[[PADV]]) {quantization_info = {input_zp = 42 : i32}}
+  // CHECK-DAG: %[[RESW1:.+]] = "tosa.reshape"(%[[PADW]]) {new_shape = [5, 2, 2, 2, 3, 3]}
+  // CHECK-DAG: %[[TRANS:.+]] = "tosa.transpose"(%[[RESW1]], %[[TRANSV]])
+  // CHECK-DAG: %[[RESW2:.+]] = "tosa.reshape"(%[[TRANS]]) {new_shape = [30, 2, 2, 3]}
+  // CHECK-DAG: %[[REV1:.+]] = "tosa.reverse"(%[[RESW2]]) {axis = 1 : i64}
+  // CHECK-DAG: %[[NEWWEIGHT:.+]] = "tosa.reverse"(%[[REV1]]) {axis = 2 : i64}
+
+  // Pad out the input matrix to handle the transpose conv.
+  // CHECK-DAG: %[[PAD:.+]] = "tosa.const"() {value = dense<{{\[\[}}0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32>}
+  // CHECK-DAG: %[[TRANS2:.+]] = "tosa.const"() {value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
+  // CHECK-DAG: %[[NEWINPUT:.+]] = "tosa.pad"(%arg0, %[[PAD]]) {quantization_info = {input_zp = -22 : i32}}
+
+  // Manipulate the final shape.
+  // CHECK-DAG: %[[BIAS:.+]] = "tosa.const"() {value = dense<0> : tensor<30xi32>}
+  // CHECK-DAG: %[[CONV:.+]] = "tosa.conv2d"(%[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]]) {dilation = [1, 1], pad = [0, 0, 0, 0], quantization_info = {input_zp = -22 : i32, weight_zp = 42 : i32}, stride = [1, 1]}
+  // CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = "tosa.reshape"(%[[CONV]]) {new_shape = [2, 18, 16, 2, 3, 5]}
+  // CHECK-DAG: %[[TRANS_OUT:.+]] = "tosa.transpose"(%[[RESHAPE_OUT_1]], %[[TRANS2]])
+  // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = "tosa.reshape"(%[[TRANS_OUT]]) {new_shape = [2, 36, 48, 5]}
+  // CHECK-DAG: %[[SLICE:.+]] = "tosa.slice"(%[[RESHAPE_OUT_2]]) {size = [2, 35, 47, 5], start = [0, 0, 0, 0]}
+  // CHECK: %[[ADD:.+]] = "tosa.add"(%[[SLICE]], %arg2)
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], quantization_info = {input_zp = -22 : i32, weight_zp = 42 : i32}, out_shape = [-1, -1, -1, -1], stride = [2, 3]} : (tensor<2x17x15x3xi8>, tensor<5x3x5x3xi8>, tensor<5xi32>) -> tensor<2x35x47x5xi32>
+  return %0 : tensor<2x35x47x5xi32>
+}