diff --git a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
--- a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
+++ b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
@@ -15,6 +15,7 @@
     MLIRLinalgUtils
     MLIRMath
     MLIRPass
+    MLIRTensor
     MLIRTosa
     MLIRTosaTransforms
     MLIRSupport
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
@@ -657,6 +658,83 @@
   }
 };
 
+class PadConverter : public OpRewritePattern<tosa::PadOp> {
+public:
+  using OpRewritePattern<tosa::PadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::PadOp padOp,
+                                PatternRewriter &rewriter) const final {
+    auto loc = padOp.getLoc();
+    auto input = padOp.input1();
+    auto padding = padOp.padding();
+
+    ShapedType inputTy = input.getType().cast<ShapedType>();
+    ShapedType paddingTy = padding.getType().cast<ShapedType>();
+    Type elementTy = inputTy.getElementType();
+
+    int64_t rank = inputTy.getRank();
+    if (rank != paddingTy.getDimSize(0)) {
+      return rewriter.notifyMatchFailure(
+          padOp, "Input rank does not match padding dim-0.");
+    }
+
+    if (!inputTy.hasStaticShape() || !paddingTy.hasStaticShape()) {
+      return rewriter.notifyMatchFailure(
+          padOp,
+          "Pad converter requires static shaped input / padding values.");
+    }
+
+    Value lowIndex = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(0));
+    Value highIndex =
+        rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(1));
+
+    SmallVector<Value, 3> lowValues;
+    SmallVector<Value, 3> highValues;
+
+    lowValues.reserve(rank);
+    highValues.reserve(rank);
+
+    for (int i = 0; i < rank; i++) {
+      Value inputIndex = rewriter.createOrFold<ConstantIndexOp>(loc, i);
+      Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
+          loc, padding, ValueRange({inputIndex, lowIndex}));
+      Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
+          loc, padding, ValueRange({inputIndex, highIndex}));
+
+      lowVal = rewriter.createOrFold<IndexCastOp>(loc, rewriter.getIndexType(),
+                                                  lowVal);
+      highVal = rewriter.createOrFold<IndexCastOp>(loc, rewriter.getIndexType(),
+                                                   highVal);
+
+      lowValues.push_back(lowVal);
+      highValues.push_back(highVal);
+    }
+
+    Value constant;
+    if (elementTy.isa<FloatType>())
+      constant = rewriter.create<ConstantOp>(
+          loc, rewriter.getFloatAttr(elementTy, 0.0));
+
+    if (elementTy.isa<IntegerType>() && padOp.quantization_info())
+      constant = rewriter.create<ConstantIntOp>(loc,
+                                                padOp.quantization_info()
+                                                    .getValue()
+                                                    .input_zp()
+                                                    .getValue()
+                                                    .getZExtValue(),
+                                                elementTy);
+
+    if (elementTy.isa<IntegerType>() && !padOp.quantization_info())
+      constant = rewriter.create<ConstantIntOp>(loc, 0, elementTy);
+
+    auto newPadOp = linalg::PadTensorOp::createPadScalarOp(
+        padOp.getType(), input, constant, lowValues, highValues, loc, rewriter);
+
+    rewriter.replaceOp(padOp, newPadOp.getResult());
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
@@ -680,6 +758,6 @@
       IdentityNConverter<tosa::IdentityOp>,
       IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceMinOp>,
       ReduceConverter<tosa::ReduceMaxOp>, ReduceConverter<tosa::ReduceSumOp>,
-      ReduceConverter<tosa::ReduceProdOp>, ReshapeOpConverter,
+      ReduceConverter<tosa::ReduceProdOp>, PadConverter, ReshapeOpConverter,
       TransposeConverter>(context);
 }
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
@@ -31,14 +32,15 @@
     : public TosaToLinalgOnTensorsBase<TosaToLinalgOnTensors> {
 public:
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry
-        .insert<linalg::LinalgDialect, math::MathDialect, StandardOpsDialect>();
+    registry.insert<linalg::LinalgDialect, math::MathDialect,
+                    StandardOpsDialect, tensor::TensorDialect>();
   }
 
   void runOnFunction() override {
     OwningRewritePatternList patterns;
     ConversionTarget target(getContext());
-    target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();
+    target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
+                           tensor::TensorDialect>();
     target.addIllegalDialect<tosa::TosaDialect>();
     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -433,3 +433,46 @@
   %4 = "tosa.reduce_max"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<4xi32>
   return
 }
+
+// -----
+
+func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
+  %0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  // CHECK: [[INDEX0:%.+]] = constant 0 : index
+  // CHECK: [[INDEX1:%.+]] = constant 1 : index
+  // CHECK: [[ROW0:%.+]] = constant 0 : index
+  // CHECK: [[LOW0:%.+]] = tensor.extract %cst{{\[}}[[ROW0]], [[INDEX0]]]
+  // CHECK: [[HIGH0:%.+]] = tensor.extract %cst{{\[}}[[ROW0]], [[INDEX1]]]
+  // CHECK: [[LOW0_IDX:%.+]] = index_cast %0
+  // CHECK: [[HIGH0_IDX:%.+]] = index_cast %1
+  // CHECK: [[ROW1:%.+]] = constant 1 : index
+  // CHECK: [[LOW1:%.+]] = tensor.extract %cst{{\[}}%c1_1, %c0]
+  // CHECK: [[HIGH1:%.+]] = tensor.extract %cst{{\[}}%c1_1, %c1]
+  // CHECK: [[LOW1_IDX:%.+]] = index_cast [[LOW1]]
+  // CHECK: [[HIGH1_IDX:%.+]] = index_cast [[HIGH1]]
+  // CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32
+  // CHECK: %8 = linalg.pad_tensor %arg0 low{{\[}}[[LOW0_IDX]], [[LOW1_IDX]]] high{{\[}}[[HIGH0_IDX]], [[HIGH1_IDX]]] {
+  // CHECK: ^bb0(%arg1: index, %arg2: index): // no predecessors
+  // CHECK:   linalg.yield [[CST]]
+  // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
+  %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<4x9xf32>)
+  return %1 : tensor<4x9xf32>
+}
+
+func @pad_int(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
+  %0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  // CHECK: [[CST:%.+]] = constant 0 : i32
+  // CHECK: linalg.pad_tensor
+  // CHECK: linalg.yield [[CST]]
+  %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xi32>, tensor<2x2xi32>) -> (tensor<4x9xi32>)
+  return %1 : tensor<4x9xi32>
+}
+
+func @pad_quant(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
+  %0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  // CHECK: [[CST:%.+]] = constant 42 : i32
+  // CHECK: linalg.pad_tensor
+  // CHECK: linalg.yield [[CST]]
+  %1 = "tosa.pad"(%arg0, %0) { quantization_info = { input_zp = 42 : i32}} : (tensor<1x2xi32>, tensor<2x2xi32>) -> (tensor<4x9xi32>)
+  return %1 : tensor<4x9xi32>
+}
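
For reference, here is a hand-written sketch (not compiler output) of roughly the IR the new PadConverter produces for the 2-D float case exercised by @pad_float above. The function name and SSA value names are illustrative only, and the exact naming and ordering of the generated constants may differ from what the pass actually prints:

func @pad_sketch(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
  // Row i of the padding tensor holds [low_i, high_i] for dimension i.
  %pad = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  // Per-dimension pad amounts are read out of the padding tensor and cast to index.
  %l0i = tensor.extract %pad[%c0, %c0] : tensor<2x2xi32>
  %h0i = tensor.extract %pad[%c0, %c1] : tensor<2x2xi32>
  %l1i = tensor.extract %pad[%c1, %c0] : tensor<2x2xi32>
  %h1i = tensor.extract %pad[%c1, %c1] : tensor<2x2xi32>
  %l0 = index_cast %l0i : i32 to index
  %h0 = index_cast %h0i : i32 to index
  %l1 = index_cast %l1i : i32 to index
  %h1 = index_cast %h1i : i32 to index
  // Pad value: 0.0 for floats, 0 for integers, or input_zp when quantization_info is set.
  %cst = constant 0.000000e+00 : f32
  %padded = linalg.pad_tensor %arg0 low[%l0, %l1] high[%h0, %h1] {
  ^bb0(%arg1: index, %arg2: index):
    linalg.yield %cst : f32
  } : tensor<1x2xf32> to tensor<4x9xf32>
  return %padded : tensor<4x9xf32>
}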