diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/IR/PatternMatch.h"
@@ -27,14 +28,33 @@
 
   LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
                                 PatternRewriter &rewriter) const final {
+    Location loc = sliceOp.getLoc();
     Value input = sliceOp.input();
     SmallVector<int64_t> strides;
+    auto starts = sliceOp.start();
+    auto sizes = sliceOp.size();
     strides.resize(sliceOp.getType().template cast<ShapedType>().getRank(), 1);
 
-    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
-        sliceOp, sliceOp.getType(), input, ValueRange({}), ValueRange({}),
-        ValueRange({}), sliceOp.start(), sliceOp.size(),
-        rewriter.getI64ArrayAttr(strides));
+    // Dynamic sizes run to the end of the dimension: dim(input) - start.
+    SmallVector<Value> dynSizes;
+    for (auto i : llvm::enumerate(sizes)) {
+      int64_t size = i.value().cast<IntegerAttr>().getInt();
+      size_t index = i.index();
+      if (size != ShapedType::kDynamicSize)
+        continue;
+
+      auto dim = rewriter.create<tensor::DimOp>(loc, input, index);
+      auto offset = rewriter.create<arith::ConstantOp>(
+          loc,
+          rewriter.getIndexAttr(starts[index].cast<IntegerAttr>().getInt()));
+      dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dim, offset));
+    }
+
+    auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
+        loc, sliceOp.getType(), input, ValueRange({}), dynSizes, ValueRange({}),
+        starts, sizes, rewriter.getI64ArrayAttr(strides));
+
+    rewriter.replaceOp(sliceOp, newSliceOp.getResult());
     return success();
   }
 };
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
@@ -12,6 +12,7 @@
 
 #include "../PassDetail.h"
 #include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
@@ -31,6 +32,7 @@
     RewritePatternSet patterns(&getContext());
     ConversionTarget target(getContext());
     target.addIllegalOp<tosa::SliceOp>();
+    target.addLegalDialect<arith::ArithmeticDialect>();
     target.addLegalDialect<tensor::TensorDialect>();
 
     mlir::tosa::populateTosaToTensorConversionPatterns(&patterns);
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -6,3 +6,17 @@
   %0 = "tosa.slice"(%arg0) {start = [2], size = [1]} : (tensor<6xf32>) -> (tensor<1xf32>)
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @slice_dyn
+func.func @slice_dyn(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
+  // CHECK: %[[C0:.+]] = arith.constant 0 : index
+  // CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C0]]
+  // CHECK: %[[C2:.+]] = arith.constant 2 : index
+  // CHECK: %[[SUB:.+]] = arith.subi %[[DIM]], %[[C2]]
+  // CHECK: %[[SLICE:.+]] = tensor.extract_slice %arg0[2] [%[[SUB]]] [1]
+  // CHECK: return %[[SLICE]]
+  %0 = "tosa.slice"(%arg0) {start = [2], size = [-1]} : (tensor<?xf32>) -> (tensor<?xf32>)
+  return %0 : tensor<?xf32>
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -7897,6 +7897,7 @@
         "lib/Conversion/TosaToTensor",
     ],
     deps = [
+        ":ArithmeticDialect",
         ":ConversionPassIncGen",
         ":FuncDialect",
         ":IR",