diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h @@ -21,6 +21,7 @@ std::unique_ptr createTosaInferShapesPass(); std::unique_ptr createTosaMakeBroadcastablePass(); +std::unique_ptr createTosaMaterializePaddingPass(); std::unique_ptr createTosaTestQuantUtilAPIPass(); #define GEN_PASS_REGISTRATION diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -43,4 +43,16 @@ let constructor = "createTosaMakeBroadcastablePass()"; } +def TosaMaterializePadding : FunctionPass<"tosa-materialize-padding"> { + let summary = "TOSA padding operations out of TOSA padded ops"; + let description = [{ + Pass that separates padding behavior out of TOSA operations by inserting + explicit tosa.pad operations. The source TOSA operation is updated to no + longer have internal padding behavior. This is useful for running + optimization passes on the padding operations. + }]; + + let constructor = "createTosaMaterializePaddingPass()"; +} + #endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -1140,7 +1140,6 @@ ShapedType biasTy = bias.getType().cast(); ShapedType resultTy = op->getResult(0).getType().cast(); - Type inputETy = inputTy.getElementType(); Type resultETy = resultTy.getElementType(); auto padAttr = op->getAttr("pad").cast(); @@ -1167,34 +1166,10 @@ auto weightShape = weightTy.getShape(); auto resultShape = resultTy.getShape(); - // Apply padding as necessary. 
- Attribute zeroAttr = rewriter.getZeroAttr(inputETy); - if (isQuantized) { - auto quantizationInfo = - op->getAttr("quantization_info").cast(); - auto iZp = quantizationInfo.input_zp().getValue().getSExtValue(); - - int64_t intMin = - APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth()) - .getSExtValue(); - int64_t intMax = - APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth()) - .getSExtValue(); - - if (iZp < intMin || iZp > intMax) - return rewriter.notifyMatchFailure( - op, "tosa.depthwise_conv op quantization has zp outside of input " - "range"); - - zeroAttr = rewriter.getIntegerAttr(inputETy, iZp); - } - llvm::SmallVector pad; - pad.resize(2, 0); getValuesFromIntArrayAttribute(padAttr, pad); - pad.resize(pad.size() + 2, 0); - - input = applyPad(loc, input, pad, zeroAttr, rewriter); + if (llvm::any_of(pad, [](int64_t p) { return p != 0; })) + return failure(); // Extract the attributes for convolution. llvm::SmallVector stride, dilation; diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRTosaTransforms TosaInferShapes.cpp TosaMakeBroadcastable.cpp + TosaMaterializePadding.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMaterializePadding.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMaterializePadding.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMaterializePadding.cpp @@ -0,0 +1,198 @@ +//===- TosaMaterializePadding.cpp +//------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Materializes padding operations out of TOSA operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
+#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include <utility>
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+/// Appends the sign-extended value of every element of `attr` to
+/// `arrayValues`. Entries already present in `arrayValues` are kept, which
+/// lets a caller pre-seed leading values before the call.
+template <typename T>
+static void getValuesFromIntArrayAttribute(ArrayAttr attr,
+                                           SmallVector<T> &arrayValues) {
+  for (Attribute val : attr.getValue())
+    arrayValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
+}
+
+/// Creates a `TosaOp` at `loc` with result type `result_ty` and the given
+/// builder arguments, then tries to refine the type of result 0 via shape
+/// inference on that single op.
+///
+/// The op is returned unchanged when it does not implement
+/// InferShapedTypeOpInterface or when inference fails. Otherwise the result
+/// type becomes the join of what `result_ty` already states and the inferred
+/// shape.
+template <typename TosaOp, typename... Args>
+TosaOp CreateOpAndInfer(PatternRewriter &rewriter, Location loc, Type result_ty,
+                        Args &&...args) {
+  auto op =
+      rewriter.create<TosaOp>(loc, result_ty, std::forward<Args>(args)...);
+
+  InferShapedTypeOpInterface shapeInterface =
+      dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
+  if (!shapeInterface)
+    return op;
+
+  SmallVector<ShapedTypeComponents> returnedShapes;
+  if (failed(shapeInterface.inferReturnTypeComponents(
+          op.getContext(), op.getLoc(), op->getOperands(),
+          op->getAttrDictionary(), op->getRegions(), returnedShapes)))
+    return op;
+
+  // We need to use the element type of the existing result type to generate
+  // the new result shaped type. This is because rescale can include a cast to
+  // different bit-width types and does not have a TypeAttr to define the
+  // target type.
+  auto result = op->getResult(0);
+  auto predictedShape = returnedShapes[0];
+  auto currentKnowledge = ValueKnowledge::getKnowledgeFromType(result_ty);
+
+  // Compute the knowledge based on the inferred type.
+  auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
+  inferredKnowledge.dtype = result_ty.cast<ShapedType>().getElementType();
+  inferredKnowledge.hasRank = predictedShape.hasRank();
+  if (predictedShape.hasRank()) {
+    for (auto dim : predictedShape.getDims())
+      inferredKnowledge.sizes.push_back(dim);
+  }
+
+  // Compute the new type based on the joined version.
+  auto newKnowledge = ValueKnowledge::join(currentKnowledge, inferredKnowledge);
+  result.setType(newKnowledge.getType());
+  return op;
+}
+
+/// Same as CreateOpAndInfer, but additionally replaces `op` with the results
+/// of the newly created operation. The new op is created at `op`'s location.
+template <typename TosaOp, typename... Args>
+void CreateReplaceOpAndInfer(PatternRewriter &rewriter, Operation *op,
+                             Type result_ty, Args &&...args) {
+  auto result = CreateOpAndInfer<TosaOp>(rewriter, op->getLoc(), result_ty,
+                                         std::forward<Args>(args)...);
+  rewriter.replaceOp(op, result->getResults());
+}
+
+/// Moves the implicit padding of a TOSA convolution-like op `ConvOp` into an
+/// explicit tosa.pad on operand 0 and resets the op's `pad` attribute to
+/// zeros. `ConvOp` must provide `pad()` and `quantization_info()` accessors.
+///
+/// The match fails when every pad amount is already zero, so an op rewritten
+/// by this pattern is not matched again.
+template <typename ConvOp>
+class MaterializeConvPadding : public OpRewritePattern<ConvOp> {
+public:
+  using OpRewritePattern<ConvOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConvOp op,
+                                PatternRewriter &rewriter) const final {
+    Location loc = op->getLoc();
+    Value input = op->getOperand(0);
+    TensorType inputTy = input.getType().cast<TensorType>();
+
+    // Pad amounts for tosa.pad: two leading zeros, the op's own pad values,
+    // then two trailing zeros (presumably the batch and channel dimensions of
+    // an NHWC input -- confirm against the op definition).
+    ArrayAttr padAttr = op.pad();
+    llvm::SmallVector<int64_t> pad;
+    pad.resize(2, 0);
+    getValuesFromIntArrayAttribute(padAttr, pad);
+    pad.resize(pad.size() + 2, 0);
+
+    // No padding is required so we exit.
+    if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
+      return failure();
+
+    // NOTE(review): getI64TensorAttr yields a rank-1 tensor (tensor<8xi64> in
+    // the tests); the TOSA specification describes the padding operand as
+    // shape [rank, 2] -- confirm consumers of tosa.pad accept the flat form.
+    auto newPadAttr = rewriter.getI64TensorAttr(pad);
+    auto padConst = CreateOpAndInfer<tosa::ConstOp>(
+        rewriter, loc, newPadAttr.getType(), newPadAttr);
+
+    // Carry the input zero point over to the pad op when the convolution is
+    // quantized; otherwise the attribute stays null.
+    PadOpQuantizationAttr padOpQuantInfo;
+    auto quantInfo = op.quantization_info();
+    if (quantInfo.hasValue()) {
+      padOpQuantInfo = tosa::PadOpQuantizationAttr::get(
+          quantInfo.getValue().input_zp(), op.getContext());
+    }
+
+    // Create the explicit padding. The result starts out unranked and is
+    // refined by shape inference inside CreateOpAndInfer when possible.
+    Value paddedInput = CreateOpAndInfer<tosa::PadOp>(
+        rewriter, loc, UnrankedTensorType::get(inputTy.getElementType()), input,
+        padConst, padOpQuantInfo);
+
+    // The matched op is modified in place, so the change must be reported
+    // through the rewriter to keep the pattern driver's state consistent.
+    rewriter.updateRootInPlace(op, [&]() {
+      op->setOperand(0, paddedInput);
+
+      // Clear out the padding, as it is materialized externally.
+      llvm::SmallVector<int64_t> zeroPad = {0, 0, 0, 0};
+      op->setAttr("pad", rewriter.getI64ArrayAttr(zeroPad));
+    });
+
+    return success();
+  }
+};
+
+/// Padding materialization for tosa.depthwise_conv2d.
+using PadDepthwiseConv2d = MaterializeConvPadding<tosa::DepthwiseConv2DOp>;
+
+/// Padding materialization for tosa.conv2d.
+using PadConv2d = MaterializeConvPadding<tosa::Conv2DOp>;
+
+/// Function pass that greedily applies the padding-materialization patterns
+/// and signals pass failure if the greedy driver reports failure.
+struct TosaMaterializePadding
+    : public TosaMaterializePaddingBase<TosaMaterializePadding> {
+public:
+  void runOnFunction() override {
+    RewritePatternSet patterns(&getContext());
+    patterns.insert<PadConv2d, PadDepthwiseConv2d>(&getContext());
+
+    if (failed(applyPatternsAndFoldGreedily(getFunction(),
+                                            std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
+} // anonymous namespace
+
+std::unique_ptr<Pass> mlir::tosa::createTosaMaterializePaddingPass() {
+  return std::make_unique<TosaMaterializePadding>();
+}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -1613,23 +1613,27 @@ // CHECK-LABEL: @depthwise_conv_quant func @depthwise_conv_quant(%arg0 : tensor<1x12x12x4xi8>, %arg1 : tensor<3x3x4x128xi8>, %arg2 : tensor<512xi32>) -> () { - // CHECK: [[PADV:%.+]] = arith.constant -128 - // CHECK: [[PAD:%.+]] = linalg.pad_tensor %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0] - // CHECK: linalg.yield [[PADV]] - - // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 12, 12, 4, 128] + // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 10, 10, 4, 128] // CHECK: [[CST0:%.+]] = arith.constant 0 // CHECK: [[FILL:%.+]] = linalg.fill([[CST0]], [[INIT]]) - // CHECK: [[OUT:%.+]] = linalg.init_tensor [1, 12, 12, 512] + // CHECK: [[OUT:%.+]] = linalg.init_tensor [1, 10, 10, 512] // CHECK: [[C128:%.+]] = arith.constant -128 // CHECK: [[C42:%.+]] = arith.constant 42 - // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv2D_nhwc_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins([[PAD]], %arg1, [[C128]], [[C42]] : tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, i32, i32) outs([[FILL]] : tensor<1x12x12x4x128xi32>) + // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv2D_nhwc_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : 
tensor<2xi64>} ins(%arg0, %arg1, [[C128]], [[C42]] : tensor<1x12x12x4xi8>, tensor<3x3x4x128xi8>, i32, i32) outs([[FILL]] : tensor<1x10x10x4x128xi32>) // CHECK: [[COLLAPSED:%.+]] = linalg.tensor_collapse_shape [[DEPTH]] {{\[}}[0], [1], [2], [3, 4]] - // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, [[COLLAPSED]] : tensor<512xi32>, tensor<1x12x12x512xi32>) outs([[OUT]] : tensor<1x12x12x512xi32>) { + // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, [[COLLAPSED]] : tensor<512xi32>, tensor<1x10x10x512xi32>) outs([[OUT]] : tensor<1x10x10x512xi32>) { // CHECK: ^bb0(%arg3: i32, %arg4: i32, %arg5: i32): // no predecessors // CHECK: [[ADD:%.+]] = arith.addi %arg3, %arg4 : i32 // CHECK: linalg.yield [[ADD]] : i32 - // CHECK: } -> tensor<1x12x12x512xi32> + // CHECK: } -> tensor<1x10x10x512xi32> + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], quantization_info = {input_zp = -128 : i32, weight_zp = 42 : i32}, stride = [1, 1], dilation = [1, 1] } : (tensor<1x12x12x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>) -> tensor<1x10x10x512xi32> + return +} + +// ----- + +// expected-error @+2 {{failed to legalize operation 'tosa.depthwise_conv2d'}} +func @depthwise_conv_fail(%arg0 : tensor<1x12x12x4xi8>, %arg1 : tensor<3x3x4x128xi8>, %arg2 : tensor<512xi32>) -> () { %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [1, 1, 1, 1], quantization_info = {input_zp = -128 : i32, weight_zp = 42 : i32}, stride = [1, 1], dilation = [1, 1] } : (tensor<1x12x12x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>) -> tensor<1x12x12x512xi32> return } diff --git a/mlir/test/Dialect/Tosa/materialize-padding.mlir b/mlir/test/Dialect/Tosa/materialize-padding.mlir new file mode 100644 --- /dev/null +++ 
b/mlir/test/Dialect/Tosa/materialize-padding.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt --split-input-file --tosa-materialize-padding %s -verify-diagnostics -o -| FileCheck %s
+
+// CHECK-LABEL: @depthwise_conv
+func @depthwise_conv(%arg0 : tensor<1x12x12x4xi8>, %arg1 : tensor<3x3x4x128xi8>, %arg2 : tensor<512xi32>) -> (tensor<1x12x12x512xi32>) {
+  // CHECK: %[[CONST:.+]] = "tosa.const"() {value = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xi64>}
+  // CHECK: %[[PAD:.+]] = "tosa.pad"(%arg0, %[[CONST]]) {quantization_info = {input_zp = -128 : i32}}
+  // CHECK: %[[CONV:.+]] = "tosa.depthwise_conv2d"(%[[PAD]], %arg1, %arg2)
+  // CHECK-SAME: pad = [0, 0, 0, 0]
+  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [1, 1, 1, 1], quantization_info = {input_zp = -128 : i32, weight_zp = 42 : i32}, stride = [1, 1], dilation = [1, 1] } : (tensor<1x12x12x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>) -> tensor<1x12x12x512xi32>
+  return %0 : tensor<1x12x12x512xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @conv2d_padded_f32
+func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28x3x3x28xf32>, %bias: tensor<28xf32>) -> (tensor<1x45x40x28xf32>) {
+  // CHECK: %[[CONST:.+]] = "tosa.const"() {value = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xi64>}
+  // CHECK: %[[PAD:.+]] = "tosa.pad"(%arg0, %[[CONST]])
+  // CHECK: %[[CONV:.+]] = "tosa.conv2d"(%[[PAD]], %arg1, %arg2)
+  // CHECK-SAME: pad = [0, 0, 0, 0]
+  %0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [1, 1, 1, 1], stride = [1, 1], dilation = [2, 1]} : (tensor<1x47x40x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>) -> (tensor<1x45x40x28xf32>)
+  return %0 : tensor<1x45x40x28xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @conv2d_quant
+func @conv2d_quant(%arg0 : tensor<1x12x12x1xi8>, %arg1 : tensor<1024x3x3x1xi8>, %arg2 : tensor<1024xi32>) -> (tensor<1x12x12x1024xi32>) {
+  // CHECK: %[[CONST:.+]] = "tosa.const"() {value = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xi64>} : () -> tensor<8xi64>
+  // CHECK: %[[PAD:.+]] = "tosa.pad"(%arg0, %[[CONST]]) {quantization_info = {input_zp = -22 : i32}}
+  // CHECK: %[[CONV:.+]] = "tosa.conv2d"(%[[PAD]], %arg1, %arg2)
+  // CHECK-SAME: pad = [0, 0, 0, 0]
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [1, 1, 1, 1], quantization_info = {input_zp = -22 : i32, weight_zp = 42 : i32}, stride = [1, 1]} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>) -> tensor<1x12x12x1024xi32>
+  return %0 : tensor<1x12x12x1024xi32>
+}