diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1417,6 +1417,9 @@
 
   let builders = [Tosa_PadOpQuantInfoBuilder,
                   Tosa_ExplicitValuePadOpQuantInfoBuilder];
+
+  let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -377,5 +377,60 @@
 }
 
+// Materializes the implicit pad value of a tosa.pad without an explicit
+// pad_const operand: 0.0 for float element types, 0 for integer element types
+// without quantization info, and the input zero point for quantized integers.
+struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::PadOp op,
+                                PatternRewriter &rewriter) const override {
+    // Already explicit; nothing to materialize.
+    if (op.pad_const())
+      return failure();
+
+    auto input = op.input1();
+    auto padding = op.padding();
+
+    ShapedType inputTy = input.getType().cast<ShapedType>();
+    Type elementTy = inputTy.getElementType();
+
+    Attribute constantAttr;
+    if (elementTy.isa<FloatType>()) {
+      constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
+    } else if (elementTy.isa<IntegerType>() && !op.quantization_info()) {
+      constantAttr = rewriter.getIntegerAttr(elementTy, 0);
+    } else if (elementTy.isa<IntegerType>() && op.quantization_info()) {
+      // The zero point is a signed quantity: sign-extend it. Zero-extending
+      // turns a negative zero point into a large positive pad value whenever
+      // the element type is wider than the zero-point attribute.
+      auto value = op.quantization_info().getValue().input_zp().getValue();
+      constantAttr = rewriter.getIntegerAttr(elementTy, value.getSExtValue());
+    }
+
+    if (!constantAttr) {
+      return rewriter.notifyMatchFailure(
+          op, "tosa.pad canonicalization encountered an unknown element type");
+    }
+
+    // Wrap the pad value in a rank-0 tosa.const and rebuild the pad with it
+    // as the explicit third operand, carrying over the original attributes.
+    auto denseAttr = DenseElementsAttr::get(
+        RankedTensorType::get({}, elementTy), constantAttr);
+    auto constantVal = rewriter.create<tosa::ConstOp>(
+        op.getLoc(), denseAttr.getType(), denseAttr);
+
+    rewriter.replaceOpWithNewOp<tosa::PadOp>(
+        op, op.getType(), ValueRange{input, padding, constantVal},
+        op->getAttrs());
+    return success();
+  }
+};
+
+void PadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                        MLIRContext *context) {
+  results.insert<MaterializePadValue>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // Operator Folders.
 //===----------------------------------------------------------------------===//
@@ -415,4 +470,21 @@
   return input1();
 }
 
+OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
+  // A pad whose padding amounts are all zero is an identity. A fold result
+  // must have exactly the op's result type, so only fold when the declared
+  // result type matches the input type (it may be more or less refined).
+  if (input1().getType() != getType())
+    return {};
+
+  // operands[1] is null when the padding is not a constant; dyn_cast_or_null
+  // also rejects attribute kinds that a plain cast would assert on.
+  auto densePad = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
+  if (densePad && densePad.isSplat() &&
+      densePad.getSplatValue<APInt>().isZero())
+    return input1();
+
+  return {};
+}
+
 OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -67,5 +67,68 @@
 }
 
+// -----
+
+// CHECK-LABEL: @pad_noop
+func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // CHECK: return %arg0
+  %0 = "tosa.const"() { value = dense<0> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
+  %1 = "tosa.pad"(%arg0, %0) : (tensor<?x?xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @pad_noop_type_mismatch
+func @pad_noop_type_mismatch(%arg0: tensor<?x?xf32>) -> tensor<2x2xf32> {
+  // CHECK: "tosa.pad"
+  %0 = "tosa.const"() { value = dense<0> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
+  %1 = "tosa.pad"(%arg0, %0) : (tensor<?x?xf32>, tensor<2x2xi32>) -> tensor<2x2xf32>
+  return %1 : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @pad_determine_val_i32
+func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
+  // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<0> : tensor<i32>}
+  // CHECK: "tosa.pad"(%arg0, %arg1, %[[ZERO]])
+  %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
+  %1 = "tosa.pad"(%arg0, %arg1) : (tensor<?x?xi32>, tensor<2x2xi32>) -> tensor<?x?xi32>
+  return %1 : tensor<?x?xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @pad_determine_val_f32
+func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xf32> {
+  // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<f32>}
+  // CHECK: "tosa.pad"(%arg0, %arg1, %[[ZERO]])
+  %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
+  %1 = "tosa.pad"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @pad_determine_val_quant
+func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
+  // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<42> : tensor<i32>}
+  // CHECK: "tosa.pad"(%arg0, %arg1, %[[ZERO]])
+  %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
+  %1 = "tosa.pad"(%arg0, %arg1) { quantization_info = {input_zp = 42:i32} } : (tensor<?x?xi32>, tensor<2x2xi32>) -> tensor<?x?xi32>
+  return %1 : tensor<?x?xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @pad_determine_val_quant_negative_zp
+func @pad_determine_val_quant_negative_zp(%arg0: tensor<?x?xi64>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi64> {
+  // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<-128> : tensor<i64>}
+  // CHECK: "tosa.pad"(%arg0, %arg1, %[[ZERO]])
+  %0 = "tosa.pad"(%arg0, %arg1) { quantization_info = {input_zp = -128:i32} } : (tensor<?x?xi64>, tensor<2x2xi32>) -> tensor<?x?xi64>
+  return %0 : tensor<?x?xi64>
+}
+
 // -----
 
 // CHECK-LABEL: @mul_one_different_shape