diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -1932,81 +1932,6 @@ } }; -class PadConverter : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tosa::PadOp padOp, - PatternRewriter &rewriter) const final { - auto loc = padOp.getLoc(); - auto input = padOp.getInput1(); - auto padding = padOp.getPadding(); - - ShapedType inputTy = input.getType().cast(); - Type elementTy = inputTy.getElementType(); - int64_t rank = inputTy.getRank(); - - // Setup the default constantAttr. - - Value padConstant; - - if (padOp.getPadConst()) { - padConstant = rewriter.createOrFold( - loc, padOp.getPadConst(), ValueRange({})); - } else { - Attribute constantAttr; - if (elementTy.isa()) { - constantAttr = rewriter.getFloatAttr(elementTy, 0.0); - } else if (elementTy.isa() && !padOp.getQuantizationInfo()) { - constantAttr = rewriter.getIntegerAttr(elementTy, 0); - } else if (elementTy.isa() && padOp.getQuantizationInfo()) { - int64_t value = padOp.getQuantizationInfo()->getInputZp(); - constantAttr = rewriter.getIntegerAttr(elementTy, value); - } - if (constantAttr) - padConstant = rewriter.create(loc, constantAttr); - } - - if (!padConstant) { - return rewriter.notifyMatchFailure( - padOp, "tosa.pad was unable to determine the pad constant value."); - } - - Value lowIndex = - rewriter.create(loc, rewriter.getIndexAttr(0)); - Value highIndex = - rewriter.create(loc, rewriter.getIndexAttr(1)); - - SmallVector lowValues; - SmallVector highValues; - - lowValues.reserve(rank); - highValues.reserve(rank); - - for (int i = 0; i < rank; i++) { - Value inputIndex = rewriter.createOrFold(loc, i); - Value lowVal = rewriter.createOrFold( - loc, padding, ValueRange({inputIndex, lowIndex})); - Value highVal = rewriter.createOrFold( - loc, padding, ValueRange({inputIndex, highIndex})); - - lowVal = rewriter.createOrFold( - loc, rewriter.getIndexType(), lowVal); - highVal = rewriter.createOrFold( - loc, rewriter.getIndexType(), highVal); - - lowValues.push_back(lowVal); - highValues.push_back(highVal); - } - - auto newPadOp = rewriter.create( - loc, padOp.getType(), input, lowValues, highValues, padConstant); - - rewriter.replaceOp(padOp, newPadOp.getResult()); - return success(); - } -}; - // Tosa argmax lowering represents the ArgMax op as an linalg.indexed_generic // op, producing two output buffers. // @@ -2375,7 +2300,6 @@ ArgMaxConverter, ConcatConverter, GatherConverter, - PadConverter, ReshapeConverterCollapse, ReshapeConverterExpand, ReshapeConverterCollapseExpand, diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp @@ -56,6 +56,7 @@ target.addLegalOp(); target.addLegalOp(); target.addLegalOp(); + target.addLegalOp(); target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp --- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp +++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp @@ -22,7 +22,7 @@ namespace { -class SliceOpConverter : public OpRewritePattern { +class SliceConverter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -59,9 +59,84 @@ } }; +class PadConverter : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::PadOp padOp, + PatternRewriter &rewriter) const final { + auto loc = padOp.getLoc(); + auto input = padOp.getInput1(); + auto padding = padOp.getPadding(); + + ShapedType inputTy = input.getType().cast(); + Type elementTy = inputTy.getElementType(); + int64_t rank = inputTy.getRank(); + + // Setup the default constantAttr. + + Value padConstant; + + if (padOp.getPadConst()) { + padConstant = rewriter.createOrFold( + loc, padOp.getPadConst(), ValueRange({})); + } else { + Attribute constantAttr; + if (elementTy.isa()) { + constantAttr = rewriter.getFloatAttr(elementTy, 0.0); + } else if (elementTy.isa() && !padOp.getQuantizationInfo()) { + constantAttr = rewriter.getIntegerAttr(elementTy, 0); + } else if (elementTy.isa() && padOp.getQuantizationInfo()) { + int64_t value = padOp.getQuantizationInfo()->getInputZp(); + constantAttr = rewriter.getIntegerAttr(elementTy, value); + } + if (constantAttr) + padConstant = rewriter.create(loc, constantAttr); + } + + if (!padConstant) { + return rewriter.notifyMatchFailure( + padOp, "tosa.pad was unable to determine the pad constant value."); + } + + Value lowIndex = + rewriter.create(loc, rewriter.getIndexAttr(0)); + Value highIndex = + rewriter.create(loc, rewriter.getIndexAttr(1)); + + SmallVector lowValues; + SmallVector highValues; + + lowValues.reserve(rank); + highValues.reserve(rank); + + for (int i = 0; i < rank; i++) { + Value inputIndex = rewriter.createOrFold(loc, i); + Value lowVal = rewriter.createOrFold( + loc, padding, ValueRange({inputIndex, lowIndex})); + Value highVal = rewriter.createOrFold( + loc, padding, ValueRange({inputIndex, highIndex})); + + lowVal = rewriter.createOrFold( + loc, rewriter.getIndexType(), lowVal); + highVal = rewriter.createOrFold( + loc, rewriter.getIndexType(), highVal); + + lowValues.push_back(lowVal); + highValues.push_back(highVal); + } + + auto newPadOp = rewriter.create( + loc, padOp.getType(), input, lowValues, highValues, padConstant); + + rewriter.replaceOp(padOp, newPadOp.getResult()); + return success(); + } +}; + } // namespace void mlir::tosa::populateTosaToTensorConversionPatterns( RewritePatternSet *patterns) { - patterns->add(patterns->getContext()); + patterns->add(patterns->getContext()); } diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp --- a/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp +++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp @@ -36,6 +36,7 @@ RewritePatternSet patterns(&getContext()); ConversionTarget target(getContext()); target.addIllegalOp(); + target.addIllegalOp(); target.addLegalDialect(); target.addLegalDialect(); diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -1301,93 +1301,6 @@ // ----- -// CHECK-LABEL: @pad_float -// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]: -func.func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) { - %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> - // TODO: Output contains multiple "arith.constant 1 : index". - // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index - // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index - // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index - // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index - // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32 - // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] { - // CHECK: tensor.yield [[CST]] - // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32> - %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<4x9xf32>) - return %1 : tensor<4x9xf32> -} - -func.func @pad_int(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) { - %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> - // CHECK: [[CST:%.+]] = arith.constant 0 : i32 - // CHECK: tensor.pad - // CHECK: tensor.yield [[CST]] - %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xi32>, tensor<2x2xi32>) -> (tensor<4x9xi32>) - return %1 : tensor<4x9xi32> -} - -func.func @pad_quant(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) { - %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> - // CHECK: [[CST:%.+]] = arith.constant 42 : i32 - // CHECK: tensor.pad - // CHECK: tensor.yield [[CST]] - %1 = "tosa.pad"(%arg0, %0) {quantization_info = #tosa.pad_quant} : (tensor<1x2xi32>, tensor<2x2xi32>) -> (tensor<4x9xi32>) - return %1 : tensor<4x9xi32> -} - -// ----- - -func.func @pad_float_explicit(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) { - %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> - // TODO: Output contains multiple "arith.constant 1 : index". - // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index - // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index - // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index - // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index - // CHECK-DAG: [[CST:%.+]] = arith.constant 4.200000e+01 : f32 - // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] { - // CHECK: tensor.yield [[CST]] - // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32> - %1 = arith.constant dense<42.0> : tensor - %2 = "tosa.pad"(%arg0, %0, %1) : (tensor<1x2xf32>, tensor<2x2xi32>, tensor) -> (tensor<4x9xf32>) - return %2 : tensor<4x9xf32> -} - -// ----- - -func.func @pad_dyn_input(%arg0 : tensor) -> (tensor) { - %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> - // TODO: Output contains multiple "arith.constant 1 : index". - // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index - // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index - // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index - // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index - // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32 - // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] { - // CHECK: tensor.yield [[CST]] - // CHECK: } : tensor to tensor - %1 = "tosa.pad"(%arg0, %0) : (tensor, tensor<2x2xi32>) -> (tensor) - return %1 : tensor -} - -func.func @pad_dyn_padding(%arg0 : tensor<1x2xf32>) -> (tensor) { - %0 = arith.constant dense<[[-1, 2], [3, 4]]> : tensor<2x2xi32> - // TODO: Output contains multiple "arith.constant 1 : index". - // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index - // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index - // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index - // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index - // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32 - // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] { - // CHECK: tensor.yield [[CST]] - // CHECK: } : tensor<1x2xf32> to tensor - %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor) - return %1 : tensor -} - -// ----- - // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)> // CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)> diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir --- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir +++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir @@ -19,3 +19,90 @@ %0 = "tosa.slice"(%arg0) {start = [2], size = [-1]} : (tensor) -> (tensor) return %0 : tensor } + +// ----- + +// CHECK-LABEL: @pad_float +// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]: +func.func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) { + %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> + // TODO: Output contains multiple "arith.constant 1 : index". + // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index + // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index + // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index + // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index + // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32 + // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] { + // CHECK: tensor.yield [[CST]] + // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32> + %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<4x9xf32>) + return %1 : tensor<4x9xf32> +} + +func.func @pad_int(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) { + %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> + // CHECK: [[CST:%.+]] = arith.constant 0 : i32 + // CHECK: tensor.pad + // CHECK: tensor.yield [[CST]] + %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xi32>, tensor<2x2xi32>) -> (tensor<4x9xi32>) + return %1 : tensor<4x9xi32> +} + +func.func @pad_quant(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) { + %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> + // CHECK: [[CST:%.+]] = arith.constant 42 : i32 + // CHECK: tensor.pad + // CHECK: tensor.yield [[CST]] + %1 = "tosa.pad"(%arg0, %0) {quantization_info = #tosa.pad_quant} : (tensor<1x2xi32>, tensor<2x2xi32>) -> (tensor<4x9xi32>) + return %1 : tensor<4x9xi32> +} + +// ----- + +func.func @pad_float_explicit(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) { + %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> + // TODO: Output contains multiple "arith.constant 1 : index". + // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index + // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index + // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index + // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index + // CHECK-DAG: [[CST:%.+]] = arith.constant 4.200000e+01 : f32 + // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] { + // CHECK: tensor.yield [[CST]] + // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32> + %1 = arith.constant dense<42.0> : tensor + %2 = "tosa.pad"(%arg0, %0, %1) : (tensor<1x2xf32>, tensor<2x2xi32>, tensor) -> (tensor<4x9xf32>) + return %2 : tensor<4x9xf32> +} + +// ----- + +func.func @pad_dyn_input(%arg0 : tensor) -> (tensor) { + %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> + // TODO: Output contains multiple "arith.constant 1 : index". + // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index + // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index + // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index + // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index + // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32 + // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] { + // CHECK: tensor.yield [[CST]] + // CHECK: } : tensor to tensor + %1 = "tosa.pad"(%arg0, %0) : (tensor, tensor<2x2xi32>) -> (tensor) + return %1 : tensor +} + +func.func @pad_dyn_padding(%arg0 : tensor<1x2xf32>) -> (tensor) { + %0 = arith.constant dense<[[-1, 2], [3, 4]]> : tensor<2x2xi32> + // TODO: Output contains multiple "arith.constant 1 : index". + // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index + // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index + // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index + // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index + // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32 + // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] { + // CHECK: tensor.yield [[CST]] + // CHECK: } : tensor<1x2xf32> to tensor + %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor) + return %1 : tensor +}