diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
--- a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
+++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
@@ -32,9 +32,32 @@
   }
 };
 
+/// Lowers tosa.slice to std.subtensor. The static `start` and `size`
+/// attributes become the subtensor's static offsets and sizes; tosa.slice
+/// carries no strides, so a stride of 1 is used in every dimension.
+class SliceOpConverter : public OpRewritePattern<tosa::SliceOp> {
+public:
+  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
+                                PatternRewriter &rewriter) const final {
+    Value input = sliceOp.input();
+    int64_t rank = sliceOp.getType().cast<ShapedType>().getRank();
+    SmallVector<int64_t> strides(rank, 1);
+
+    // Offsets, sizes and strides are all static, so no dynamic operands.
+    rewriter.replaceOpWithNewOp<SubTensorOp>(
+        sliceOp, sliceOp.getType(), input, ValueRange({}), ValueRange({}),
+        ValueRange({}), sliceOp.start(), sliceOp.size(),
+        rewriter.getI64ArrayAttr(strides));
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tosa::populateTosaToStandardConversionPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns) {
-  patterns->insert<ConstOpConverter>(context);
+  patterns->insert<ConstOpConverter, SliceOpConverter>(context);
 }
diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
--- a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
+++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
@@ -32,7 +32,8 @@
     OwningRewritePatternList patterns;
     ConversionTarget target(getContext());
     target.addIllegalOp<tosa::ConstOp>();
-    target.addLegalOp<ConstantOp>();
+    target.addIllegalOp<tosa::SliceOp>();
+    target.addLegalDialect<StandardOpsDialect>();
 
     auto *op = getOperation();
     mlir::tosa::populateTosaToStandardConversionPatterns(op->getContext(),
diff --git a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
--- a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
+++ b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
@@ -8,3 +8,12 @@
   // CHECK: return [[C3]]
   return %0 : tensor<i32>
 }
+
+// -----
+
+// CHECK-LABEL: func @slice
+func @slice(%arg0: tensor<6xf32>) ->() {
+  // CHECK: [[SLICE:%.+]] = subtensor %arg0[2] [1] [1]
+  %0 = "tosa.slice"(%arg0) {start = [2], size = [1]} : (tensor<6xf32>) -> (tensor<1xf32>)
+  return
+}