diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1681,11 +1681,8 @@
   LogicalResult matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const override {
+    auto inputType = op.getOperand(0).getType().template cast<ShapedType>();
     auto resultType = op.getType().dyn_cast<RankedTensorType>();
-    if (!resultType || !resultType.hasStaticShape()) {
-      return rewriter.notifyMatchFailure(op,
-                                         "expected static shaped tensor type");
-    }
 
     Location loc = op.getLoc();
     int axis = op.axis();
@@ -1697,9 +1694,14 @@
     strides.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 1));
     offsets.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 0));
 
+    SmallVector<Value> dynDims;
     for (int i = 0; i < rank; ++i) {
       sizes.push_back(rewriter.createOrFold<tensor::DimOp>(
           loc, adaptor.getOperands()[0], i));
+      if (inputType.isDynamicDim(i)) {
+        dynDims.push_back(
+            rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
+      }
     }
 
     Value resultDimSize = sizes[axis];
@@ -1711,7 +1713,7 @@
     sizes[axis] = resultDimSize;
 
     Value init = rewriter.create<linalg::InitTensorOp>(
-        loc, resultType.getShape(), resultType.getElementType());
+        loc, dynDims, resultType.getShape(), resultType.getElementType());
 
     Value zeroVal = rewriter.createOrFold<arith::ConstantOp>(
         loc, rewriter.getZeroAttr(resultType.getElementType()));
@@ -1815,9 +1817,6 @@
     auto elementTy = inputTy.getElementType();
     int64_t rank = inputTy.getRank();
 
-    if (!inputTy.hasStaticShape() || !resultTy.hasStaticShape())
-      return failure();
-
     SmallVector<int64_t> multiples;
     getValuesFromIntArrayAttribute(op.multiples(), multiples);
@@ -1828,8 +1827,15 @@
       genericShape.push_back(inputShape[i]);
     }
 
+    SmallVector<Value> dynDims;
+    for (int i = 0; i < inputTy.getRank(); i++) {
+      if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
+        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
+      }
+    }
+
     auto initTensor = rewriter.create<linalg::InitTensorOp>(
-        op.getLoc(), ArrayRef<Value>({}), genericShape, elementTy);
+        op.getLoc(), dynDims, genericShape, elementTy);
 
     // We need to map the input shape to the non-broadcasted dimensions.
     SmallVector<AffineExpr> dimExprs;
@@ -1870,16 +1876,9 @@
     auto padding = padOp.padding();
 
     ShapedType inputTy = input.getType().cast<ShapedType>();
-    ShapedType paddingTy = padding.getType().cast<ShapedType>();
     Type elementTy = inputTy.getElementType();
     int64_t rank = inputTy.getRank();
 
-    if (!inputTy.hasStaticShape() || !paddingTy.hasStaticShape()) {
-      return rewriter.notifyMatchFailure(
-          padOp,
-          "Pad converter requires static shaped input / padding values.");
-    }
-
     // Setup the default constantAttr.
 
     Value padConstant;
@@ -1970,21 +1969,23 @@
     int axis = argmaxOp.axis();
     auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);
 
-    if (!inputTy.hasStaticShape())
-      return rewriter.notifyMatchFailure(
-          argmaxOp,
-          "tosa.arg_max to linalg.* requires statically shaped input");
-
     if (!outElementTy.isa<IntegerType>())
      return rewriter.notifyMatchFailure(
          argmaxOp,
          "tosa.arg_max to linalg.* requires integer-like result type");
 
+    SmallVector<Value> dynDims;
+    for (int i = 0; i < inputTy.getRank(); i++) {
+      if (inputTy.isDynamicDim(i) && i != axis) {
+        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
+      }
+    }
+
     // First fill the output buffer for the index.
     auto initTensorIdx =
         rewriter
-            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
-                                          resultTy.getShape(), outElementTy)
+            .create<linalg::InitTensorOp>(loc, dynDims, resultTy.getShape(),
+                                          outElementTy)
             .result();
     auto fillValueIdx = rewriter.create<arith::ConstantOp>(
         loc, rewriter.getIntegerAttr(outElementTy, 0));
@@ -1993,11 +1994,10 @@
             .result();
 
     // Second fill the output buffer for the running max.
-    auto initTensorMax =
-        rewriter
-            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
-                                          resultTy.getShape(), inElementTy)
-            .result();
+    auto initTensorMax = rewriter
+                             .create<linalg::InitTensorOp>(
+                                 loc, dynDims, resultTy.getShape(), inElementTy)
+                             .result();
 
     auto fillValueMaxAttr =
         createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
@@ -2138,18 +2138,22 @@
     auto tableTy = table.getType().cast<ShapedType>();
     auto resultTy = op.getType().cast<ShapedType>();
 
-    if (!inputTy.hasStaticShape())
-      return rewriter.notifyMatchFailure(
-          op, "require input type to have static shape");
-
     auto inputElementTy = inputTy.getElementType();
     auto tableElementTy = tableTy.getElementType();
     auto resultElementTy = resultTy.getElementType();
 
+    SmallVector<Value> dynDims;
+    for (int i = 0; i < resultTy.getRank(); ++i) {
+      if (inputTy.isDynamicDim(i)) {
+        dynDims.push_back(
+            rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
+      }
+    }
+
     auto initTensor =
         rewriter
-            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
-                                          resultTy.getShape(), resultElementTy)
+            .create<linalg::InitTensorOp>(loc, dynDims, resultTy.getShape(),
+                                          resultElementTy)
             .result();
 
     SmallVector<AffineMap> affineMaps = {
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -910,6 +910,50 @@
 
 // -----
 
+// CHECK-LABEL: @concat_non_axis_dyn
+func @concat_non_axis_dyn(%arg0: tensor<5x?xf32>, %arg1: tensor<6x?xf32>) -> () {
+  // CHECK: %[[AXIS:.+]] = arith.constant 0
+  // CHECK: %[[STRIDE:.+]] = arith.constant 1
+  // CHECK: %[[OFFSET:.+]] = arith.constant 0 : index
+  // CHECK: %[[IDX0:.+]] = arith.constant 0 : index
+  // CHECK: %[[IDX1:.+]] = arith.constant 1 : index
+  // CHECK: %[[SIZE:.+]] = tensor.dim %arg0, %[[IDX1]]
+  // CHECK: %[[IDX1_2:.+]] = arith.constant 1 : index
+  // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[IDX1_2]]
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [11, %[[DYN]]]
+  // CHECK: %[[CST:.+]] = arith.constant 0.0
+  // CHECK: %[[FILL:.+]] = linalg.fill(%[[CST]], %[[INIT]])
+  // CHECK: %[[INSERT0:.+]] = tensor.insert_slice %arg0 into %[[FILL]][0, 0] [5, %[[SIZE]]] [1, 1]
+  // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %arg1 into %[[INSERT0]][5, 0] [6, %[[SIZE]]] [1, 1]
+  %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x?xf32>, tensor<6x?xf32>) -> (tensor<11x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @concat_axis_dyn
+func @concat_axis_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<?x3xf32>) -> () {
+  // CHECK: %[[AXIS:.+]] = arith.constant 0
+  // CHECK: %[[STRIDE:.+]] = arith.constant 1
+  // CHECK: %[[OFFSET:.+]] = arith.constant 0 : index
+  // CHECK: %[[IDX0:.+]] = arith.constant 0 : index
+  // CHECK: %[[SIZE:.+]] = tensor.dim %arg0, %[[IDX0]]
+  // CHECK: %[[IDX0_2:.+]] = arith.constant 0 : index
+  // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[IDX0_2]]
+  // CHECK: %[[IDX1:.+]] = arith.constant 1 : index
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DYN]], 3]
+  // CHECK: %[[CST:.+]] = arith.constant 0.0
+  // CHECK: %[[FILL:.+]] = linalg.fill(%[[CST]], %[[INIT]])
+  // CHECK: %[[DYN1:.+]] = tensor.dim %arg0, %[[AXIS]]
+  // CHECK: %[[INSERT0:.+]] = tensor.insert_slice %arg0 into %[[FILL]][0, 0] [%[[DYN1]], 3] [1, 1]
+  // CHECK: %[[SUM:.+]] = arith.addi %[[OFFSET]], %[[DYN1]]
+  // CHECK: %[[DYN2:.+]] = tensor.dim %arg1, %[[AXIS]]
+  // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %arg1 into %[[INSERT0]][%[[SUM]], 0] [%[[DYN2]], 3] [1, 1]
+  %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<?x3xf32>, tensor<?x3xf32>) -> (tensor<?x3xf32>)
+  return
+}
+
+// -----
 
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
 // CHECK-LABEL: @rescale_i8
@@ -1150,6 +1194,44 @@
 
 // -----
 
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+// CHECK-LABEL: @tile_dyn_input
+func @tile_dyn_input(%arg0 : tensor<?x3xi8>) -> () {
+  // CHECK: %[[CST0:.+]] = arith.constant 0
+  // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[CST0]] : tensor<?x3xi8>
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [2, %[[DYN]], 1, 3]
+  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<?x3xi8>) outs(%[[INIT]] : tensor<2x?x1x3xi8>)
+  // CHECK: linalg.yield %arg1 : i8
+  // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1, 2, 3]]
+  // CHECK: tensor.expand_shape %[[COLLAPSED]] {{\[}}[0, 1]]
+  %0 = "tosa.tile"(%arg0) {multiples = [2, 1]} : (tensor<?x3xi8>) -> (tensor<?x3xi8>)
+
+  return
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+// CHECK-LABEL: @tile_dyn_multiples
+func @tile_dyn_multiples(%arg0 : tensor<2x3xi8>) -> () {
+  // CHECK: %[[CST1:.+]] = arith.constant 1
+  // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[CST1]] : tensor<2x3xi8>
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [2, 2, %[[DYN]], 3]
+  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3xi8>) outs(%[[INIT]] : tensor<2x2x?x3xi8>)
+  // CHECK: linalg.yield %arg1 : i8
+  // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1, 2, 3]]
+  // CHECK: tensor.expand_shape %[[COLLAPSED]] {{\[}}[0, 1]]
+  %0 = "tosa.tile"(%arg0) {multiples = [2, -1]} : (tensor<2x3xi8>) -> (tensor<2x?xi8>)
+
+  return
+}
+
+// -----
+
 func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
   %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
   // TODO: Output contains multiple "arith.constant 1 : index".
@@ -1205,6 +1287,40 @@
 
 // -----
 
+func @pad_dyn_input(%arg0 : tensor<1x?xf32>) -> (tensor<4x?xf32>) {
+  %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  // TODO: Output contains multiple "arith.constant 1 : index".
+  // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
+  // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
+  // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
+  // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
+  // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
+  // CHECK: tensor.pad %arg0 low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
+  // CHECK: ^bb0(%arg1: index, %arg2: index):
+  // CHECK: tensor.yield [[CST]]
+  // CHECK: } : tensor<1x?xf32> to tensor<4x?xf32>
+  %1 = "tosa.pad"(%arg0, %0) : (tensor<1x?xf32>, tensor<2x2xi32>) -> (tensor<4x?xf32>)
+  return %1 : tensor<4x?xf32>
+}
+
+func @pad_dyn_padding(%arg0 : tensor<1x2xf32>) -> (tensor<?x9xf32>) {
+  %0 = arith.constant dense<[[-1, 2], [3, 4]]> : tensor<2x2xi32>
+  // TODO: Output contains multiple "arith.constant 1 : index".
+  // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
+  // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
+  // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
+  // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
+  // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
+  // CHECK: tensor.pad %arg0 low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
+  // CHECK: ^bb0(%arg1: index, %arg2: index):
+  // CHECK: tensor.yield [[CST]]
+  // CHECK: } : tensor<1x2xf32> to tensor<?x9xf32>
+  %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<?x9xf32>)
+  return %1 : tensor<?x9xf32>
+}
+
+// -----
+
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
 // CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
@@ -1256,6 +1372,54 @@
 
 // -----
 
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
+
+func @argmax_dyn_non_axis(%arg0 : tensor<3x?xi32>) -> () {
+  // CHECK: %[[CST1:.+]] = arith.constant 1
+  // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[CST1]]
+  // CHECK: %[[IDX_INIT:.+]] = linalg.init_tensor [%[[DYN]]]
+  // CHECK: %[[IDX_MIN:.+]] = arith.constant 0 : i32
+  // CHECK: %[[IDX_FILL:.+]] = linalg.fill(%[[IDX_MIN]], %[[IDX_INIT]])
+  // CHECK: %[[VAL_INIT:.+]] = linalg.init_tensor [%[[DYN]]]
+  // CHECK: %[[VAL_MIN:.+]] = arith.constant -2147483648
+  // CHECK: %[[VAL_FILL:.+]] = linalg.fill(%[[VAL_MIN]], %[[VAL_INIT]])
+  // CHECK: linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins(%arg0 : tensor<3x?xi32>) outs(%[[IDX_FILL]], %[[VAL_FILL]] : tensor<?xi32>, tensor<?xi32>)
+  // CHECK: %[[IDX:.+]] = linalg.index 0
+  // CHECK: %[[CAST:.+]] = arith.index_cast %[[IDX]]
+  // CHECK: %[[CMP:.+]] = arith.cmpi sgt, %arg1, %arg3
+  // CHECK: %[[SELECT_VAL:.+]] = select %[[CMP]], %arg1, %arg3
+  // CHECK: %[[SELECT_IDX:.+]] = select %[[CMP]], %[[CAST]], %arg2
+  // CHECK: linalg.yield %[[SELECT_IDX]], %[[SELECT_VAL]]
+  %0 = "tosa.argmax"(%arg0) { axis = 0 : i64} : (tensor<3x?xi32>) -> (tensor<?xi32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
+
+func @argmax_dyn_axis(%arg0 : tensor<3x?xi32>) -> () {
+  // CHECK: %[[IDX_INIT:.+]] = linalg.init_tensor [3]
+  // CHECK: %[[IDX_MIN:.+]] = arith.constant 0 : i32
+  // CHECK: %[[IDX_FILL:.+]] = linalg.fill(%[[IDX_MIN]], %[[IDX_INIT]])
+  // CHECK: %[[VAL_INIT:.+]] = linalg.init_tensor [3]
+  // CHECK: %[[VAL_MIN:.+]] = arith.constant -2147483648
+  // CHECK: %[[VAL_FILL:.+]] = linalg.fill(%[[VAL_MIN]], %[[VAL_INIT]])
+  // CHECK: linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<3x?xi32>) outs(%[[IDX_FILL]], %[[VAL_FILL]] : tensor<3xi32>, tensor<3xi32>)
+  // CHECK: %[[IDX:.+]] = linalg.index 1
+  // CHECK: %[[CAST:.+]] = arith.index_cast %[[IDX]]
+  // CHECK: %[[CMP:.+]] = arith.cmpi sgt, %arg1, %arg3
+  // CHECK: %[[SELECT_VAL:.+]] = select %[[CMP]], %arg1, %arg3
+  // CHECK: %[[SELECT_IDX:.+]] = select %[[CMP]], %[[CAST]], %arg2
+  // CHECK: linalg.yield %[[SELECT_IDX]], %[[SELECT_VAL]]
+  %0 = "tosa.argmax"(%arg0) { axis = 1 : i64} : (tensor<3x?xi32>) -> (tensor<3xi32>)
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @gather_float
 func @gather_float(%arg0: tensor<2x3x2xf32>, %arg1: tensor<2x3xi32>) -> () {
   // CHECK: %[[INIT:.+]] = linalg.init_tensor [2, 3, 2]
@@ -1349,6 +1513,40 @@
 
 // -----
 
+// CHECK-LABEL: @table8_dyn
+func @table8_dyn(%arg0: tensor<?xi8>, %arg1: tensor<512xi8>) -> () {
+  // CHECK: %[[CST0:.+]] = arith.constant 0
+  // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[CST0]]
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DYN]]]
+  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<?xi8>) outs(%[[INIT]] : tensor<?xi8>)
+  // CHECK: ^bb0(%[[ARG_IN:.+]]: i8, %[[ARG_INIT:.+]]: i8)
+  // CHECK: %[[CAST:.+]] = arith.index_cast %[[ARG_IN]]
+  // CHECK: %[[OFFSET:.+]] = arith.constant 128
+  // CHECK: %[[ADD:.+]] = arith.addi %[[CAST]], %[[OFFSET]]
+  // CHECK: %[[EXTRACT:.+]] = tensor.extract %arg1[%[[ADD]]]
+  // CHECK: linalg.yield %[[EXTRACT]]
+  %0 = "tosa.table"(%arg0, %arg1) : (tensor<?xi8>, tensor<512xi8>) -> (tensor<?xi8>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @table8_dyn_table
+func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor<?xi8>) -> () {
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [6]
+  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<6xi8>) outs(%[[INIT]] : tensor<6xi8>)
+  // CHECK: ^bb0(%[[ARG_IN:.+]]: i8, %[[ARG_INIT:.+]]: i8)
+  // CHECK: %[[CAST:.+]] = arith.index_cast %[[ARG_IN]]
+  // CHECK: %[[OFFSET:.+]] = arith.constant 128
+  // CHECK: %[[ADD:.+]] = arith.addi %[[CAST]], %[[OFFSET]]
+  // CHECK: %[[EXTRACT:.+]] = tensor.extract %arg1[%[[ADD]]]
+  // CHECK: linalg.yield %[[EXTRACT]]
+  %0 = "tosa.table"(%arg0, %arg1) : (tensor<6xi8>, tensor<?xi8>) -> (tensor<6xi8>)
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @resize_nearest
 func @resize_nearest(%input: tensor<1x2x2x1xf32>) -> () {
   // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 4, 4, 1]
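
Note: every converter touched above repeats the same idiom: walk the operand's ranked shape, emit a tensor.dim for each dynamic dimension, and hand the collected Values to linalg::InitTensorOp, whose builder pairs them positionally with the dynamic entries of the static shape. A minimal sketch of that idiom follows; the helper names (getDynamicDims, createInitTensorLike) are hypothetical and not part of this patch, and the op builders assumed are the same linalg.init_tensor-era APIs the diff itself uses.

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Collect one Value per dynamic dimension of `tensor`, in dimension order,
// by materializing tensor.dim ops. Assumes `tensor` has a ranked ShapedType,
// which every converter in this patch already guarantees at this point.
static SmallVector<Value> getDynamicDims(OpBuilder &builder, Location loc,
                                         Value tensor) {
  auto type = tensor.getType().cast<ShapedType>();
  SmallVector<Value> dynDims;
  for (int i = 0; i < type.getRank(); ++i)
    if (type.isDynamicDim(i))
      dynDims.push_back(builder.create<tensor::DimOp>(loc, tensor, i));
  return dynDims;
}

// Usage mirroring the concat and table converters: the dynamic sizes are
// consumed first and matched against the '?' entries of the static shape.
static Value createInitTensorLike(OpBuilder &builder, Location loc,
                                  Value model, Type elementTy) {
  auto type = model.getType().cast<ShapedType>();
  return builder.create<linalg::InitTensorOp>(
      loc, getDynamicDims(builder, loc, model), type.getShape(), elementTy);
}

The argmax and tile converters deviate slightly from this shared shape (skipping the reduction axis, or treating a -1 multiple as dynamic), which is presumably why the patch inlines the loop in each converter instead of factoring out such a helper.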