diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1989,40 +1989,52 @@
     Value input = op.input();
     auto inputTy = input.getType().template cast<ShapedType>();
     auto resultTy = op.getType().template cast<ShapedType>();
-    auto rank = resultTy.getRank();
     auto axis = op.axis();
 
-    if (!inputTy.hasStaticShape())
-      return rewriter.notifyMatchFailure(
-          op, "No initial value found for reduction operation");
+    SmallVector<Value> dynDims;
+    for (int i = 0; i < inputTy.getRank(); i++) {
+      if (inputTy.isDynamicDim(i)) {
+        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
+      }
+    }
+
+    // Size of the reversed axis: output index i along `axis` reads input index
+    // (axisDimSize - 1 - i).
+    Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);
 
-    // First fill the output buffer with the init value.
+    // Create the output tensor; the generic below writes every element of it.
     auto initTensor = rewriter
                           .create<linalg::InitTensorOp>(
-                              loc, ArrayRef<Value>({}), inputTy.getShape(),
-                              inputTy.getElementType())
+                              loc, ArrayRef<Value>({dynDims}),
+                              inputTy.getShape(), inputTy.getElementType())
                           .result();
-
-    SmallVector<AffineExpr, 2> inputExprs;
-    inputExprs.resize(resultTy.getRank());
-
-    for (int i = 0; i < rank; i++)
-      inputExprs[i] = rewriter.getAffineDimExpr(i);
-
-    inputExprs[axis] =
-        rewriter.getAffineConstantExpr(inputTy.getDimSize(axis) - 1) -
-        inputExprs[axis];
-
     SmallVector<AffineMap, 2> affineMaps = {
-        AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs,
-                       rewriter.getContext()),
         rewriter.getMultiDimIdentityMap(resultTy.getRank())};
 
     rewriter.replaceOpWithNewOp<linalg::GenericOp>(
-        op, resultTy, op.input(), ValueRange{initTensor}, affineMaps,
+        op, resultTy, ArrayRef<Value>({}), ValueRange{initTensor}, affineMaps,
         getNParallelLoopsAttrs(resultTy.getRank()),
         [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
-          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
+          // Gather every element with tensor.extract; only the index along
+          // `axis` is mirrored, all other indices pass through unchanged.
+          llvm::SmallVector<Value> indices;
+          for (unsigned int i = 0; i < inputTy.getRank(); i++) {
+            auto index =
+                nestedBuilder.create<linalg::IndexOp>(nestedLoc, i).getResult();
+            if (i == axis) {
+              auto one = nestedBuilder.create<ConstantIndexOp>(nestedLoc, 1);
+              auto sizeMinusOne =
+                  nestedBuilder.create<SubIOp>(nestedLoc, axisDimSize, one);
+              index =
+                  nestedBuilder.create<SubIOp>(nestedLoc, sizeMinusOne, index);
+            }
+
+            indices.push_back(index);
+          }
+
+          auto extract = nestedBuilder.create<tensor::ExtractOp>(
+              nestedLoc, input, indices);
+          nestedBuilder.create<linalg::YieldOp>(nestedLoc, extract.getResult());
         });
     return success();
   }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -861,28 +861,62 @@
 
 // -----
 
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (-d0 + 4, d1)>
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 3)>
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 
 // CHECK-LABEL: @reverse
 func @reverse(%arg0: tensor<5x4xi32>) -> () {
-  // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 4]
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x4xi32>) outs([[INIT]] : tensor<5x4xi32>) {
-  // CHECK: ^bb0(%arg1: i32, %arg2: i32):
-  // CHECK: linalg.yield %arg1 : i32
+  // CHECK: %[[C0:.+]] = constant 0
+  // CHECK: %[[RDIM:.+]] = tensor.dim %arg0, %[[C0]]
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [5, 4]
+  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]]], iterator_types = ["parallel", "parallel"]} outs(%[[INIT]] : tensor<5x4xi32>)
+  // CHECK-DAG: %[[I0:.+]] = linalg.index 0
+  // CHECK-DAG: %[[I1:.+]] = linalg.index 1
+  // CHECK-DAG: %[[SUB1:.+]] = constant 1
+  // CHECK-DAG: %[[RDIM_MINUS_C1:.+]] = subi %[[RDIM]], %[[SUB1]]
+  // CHECK-DAG: %[[READ_DIM:.+]] = subi %[[RDIM_MINUS_C1]], %[[I0]]
+  // CHECK-DAG: %[[EXTRACT:.+]] = tensor.extract %arg0[%[[READ_DIM]], %[[I1]]] : tensor<5x4xi32>
+  // CHECK: linalg.yield %[[EXTRACT]]
   %0 = "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<5x4xi32>
 
-  // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 4]
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x4xi32>) outs([[INIT]] : tensor<5x4xi32>) {
-  // CHECK: ^bb0(%arg1: i32, %arg2: i32):
-  // CHECK: linalg.yield %arg1 : i32
+  // CHECK: %[[C1:.+]] = constant 1
+  // CHECK: %[[RDIM:.+]] = tensor.dim %arg0, %[[C1]]
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [5, 4]
+  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]]], iterator_types = ["parallel", "parallel"]} outs(%[[INIT]] : tensor<5x4xi32>)
+  // CHECK-DAG: %[[I0:.+]] = linalg.index 0
+  // CHECK-DAG: %[[I1:.+]] = linalg.index 1
+  // CHECK-DAG: %[[SUB1:.+]] = constant 1
+  // CHECK-DAG: %[[RDIM_MINUS_C1:.+]] = subi %[[RDIM]], %[[SUB1]]
+  // CHECK-DAG: %[[READ_DIM:.+]] = subi %[[RDIM_MINUS_C1]], %[[I1]]
+  // CHECK-DAG: %[[EXTRACT:.+]] = tensor.extract %arg0[%[[I0]], %[[READ_DIM]]] : tensor<5x4xi32>
+  // CHECK: linalg.yield %[[EXTRACT]]
   %1 = "tosa.reverse"(%arg0) {axis = 1 : i64} : (tensor<5x4xi32>) -> tensor<5x4xi32>
   return
 }
 
 // -----
 
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: @reverse_dyn
+func @reverse_dyn(%arg0: tensor<?xi32>) -> () {
+  // CHECK: %[[C0_1:.+]] = constant 0
+  // CHECK: %[[D0_1:.+]] = tensor.dim %arg0, %[[C0_1]]
+  // CHECK: %[[C0_2:.+]] = constant 0
+  // CHECK: %[[D0_2:.+]] = tensor.dim %arg0, %[[C0_2]]
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[D0_1]]]
+  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]]], iterator_types = ["parallel"]} outs(%[[INIT]] : tensor<?xi32>)
+  // CHECK-DAG: %[[I0:.+]] = linalg.index 0
+  // CHECK-DAG: %[[SUB1:.+]] = constant 1
+  // CHECK-DAG: %[[RDIM_MINUS_C1:.+]] = subi %[[D0_2]], %[[SUB1]]
+  // CHECK-DAG: %[[READ_DIM:.+]] = subi %[[RDIM_MINUS_C1]], %[[I0]]
+  // CHECK-DAG: %[[EXTRACT:.+]] = tensor.extract %arg0[%[[READ_DIM]]] : tensor<?xi32>
+  // CHECK: linalg.yield %[[EXTRACT]]
+  %0 = "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor<?xi32>) -> tensor<?xi32>
+  return
+}
+
+// -----
+
 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
 // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 