// Unified diff touching two files: it adds a ConcatOpConversion pattern to
// mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp, lists it among the
// patterns in populateTosaToLinalgOnTensorsConversionPatterns, and adds an
// @concat FileCheck test to
// mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir.
//
// NOTE(review): this copy has lost its line breaks, and calls such as
// `dyn_cast()` and `rewriter.create(loc, ...)` carry no template arguments --
// presumably stripped during extraction; confirm against the original patch
// before trying to apply or compile it.
//
// ConcatOpConversion::matchAndRewrite(op, args, rewriter):
//   - Reports a match failure ("expected static shape for output") unless the
//     result type cast succeeds and the type has a static shape.
//   - Creates the axis as an index constant, plus a single stride value (1)
//     and a single offset value (0), each replicated `rank` times.
//   - Seeds `sizes` with one dim query on args[0] per dimension.
//   - Creates the destination from the result type's static shape and element
//     type (`linalg.init_tensor` in the expected test output), then inserts
//     every operand into it in order (`subtensor_insert` in the expected test
//     output). After each insert the offset along `axis` is advanced by that
//     operand's size along `axis`.
//   - Replaces `op` with the last inserted value and returns success().
//
// NOTE(review): `resultDimSize` (the running sum of the operands' sizes along
// `axis`) is stored into sizes[axis], but the insertion loop overwrites
// sizes[axis] before anything reads it, so the sum looks unused -- confirm
// whether it can be dropped. The ARG1_AXIS / RESULT_AXIS CHECK lines in the
// test match the ops it produces.
//
// The @concat test covers axis = 0 (5x1 and 6x1 operands -> 11x1) and
// axis = 1 (the same 5x1 operand passed twice -> 5x2).
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -657,6 +657,52 @@ } }; +struct ConcatOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(tosa::ConcatOp op, ArrayRef args, + ConversionPatternRewriter &rewriter) const override { + auto resultType = op.getResult().getType().dyn_cast(); + if (!resultType || !resultType.hasStaticShape()) { + return rewriter.notifyMatchFailure(op, + "expected static shape for output"); + } + + Location loc = op.getLoc(); + int axis = op.axis(); + Value axisValue = + rewriter.create(loc, rewriter.getIndexAttr(axis)); + int rank = resultType.getRank(); + SmallVector offsets, sizes, strides; + strides.resize(rank, rewriter.create(loc, 1)); + offsets.resize(rank, rewriter.create(loc, 0)); + + for (int i = 0; i < rank; ++i) { + sizes.push_back(rewriter.create(loc, args[0], i)); + } + + Value resultDimSize = sizes[axis]; + for (auto arg : args.drop_front()) { + auto size = rewriter.create(loc, arg, axisValue); + resultDimSize = rewriter.create(loc, resultDimSize, size); + } + sizes[axis] = resultDimSize; + + Value result = rewriter.create( + loc, resultType.getShape(), resultType.getElementType()); + + for (auto arg : args) { + sizes[axis] = rewriter.create(loc, arg, axisValue); + result = rewriter.create(loc, arg, result, offsets, + sizes, strides); + offsets[axis] = rewriter.create(loc, offsets[axis], sizes[axis]); + } + rewriter.replaceOp(op, result); + return success(); + } +}; + } // namespace void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns( @@ -680,6 +726,6 @@ IdentityNConverter, IdentityNConverter, ReduceConverter, ReduceConverter, ReduceConverter, - ReduceConverter, ReshapeOpConverter, - TransposeConverter>(context); + ReduceConverter, ConcatOpConversion, + 
ReshapeOpConverter, TransposeConverter>(context); } diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -433,3 +433,43 @@ %4 = "tosa.reduce_max"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<4xi32> return } + +// ----- + +// CHECK-LABEL: @concat +func @concat(%arg0: tensor<5x1xf32>, %arg1: tensor<6x1xf32>) -> () { + // CHECK: [[AXIS:%.+]] = constant 0 + // CHECK: [[STRIDE:%.+]] = constant 1 + // CHECK: [[OFFSET:%.+]] = constant 0 : index + // CHECK: [[IDX0:%.+]] = constant 0 : index + // CHECK: [[ARG0_DIM0:%.+]] = dim %arg0, [[IDX0]] + // CHECK: [[IDX1:%.+]] = constant 1 : index + // CHECK: [[ARG0_DIM1:%.+]] = dim %arg0, [[IDX1]] + // CHECK: [[ARG1_AXIS:%.+]] = dim %arg1, [[AXIS]] + // CHECK: [[RESULT_AXIS:%.+]] = addi [[ARG0_DIM0]], [[ARG1_AXIS]] + // CHECK: [[INIT:%.+]] = linalg.init_tensor [11, 1] + // CHECK: [[ARG0_DIM0:%.+]] = dim %arg0, [[AXIS]] + // CHECK: [[INSERT0:%.+]] = subtensor_insert %arg0 into [[INIT]]{{\[}}[[OFFSET]], [[OFFSET]]] {{\[}}[[ARG0_DIM0]], [[ARG0_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]] + // CHECK: [[NEW_OFFSET:%.+]] = addi [[OFFSET]], [[ARG0_DIM0]] + // CHECK: [[ARG1_DIM0:%.+]] = dim %arg1, [[AXIS]] + // CHECK: [[INSERT1:%.+]] = subtensor_insert %arg1 into [[INSERT0]]{{\[}}[[NEW_OFFSET]], [[OFFSET]]] {{\[}}[[ARG1_DIM0]], [[ARG0_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]] + %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x1xf32>, tensor<6x1xf32>) -> (tensor<11x1xf32>) + + // CHECK: [[AXIS:%.+]] = constant 1 + // CHECK: [[STRIDE:%.+]] = constant 1 + // CHECK: [[OFFSET:%.+]] = constant 0 : index + // CHECK: [[IDX0:%.+]] = constant 0 : index + // CHECK: [[ARG0_DIM0:%.+]] = dim %arg0, [[IDX0]] + // CHECK: [[IDX1:%.+]] = constant 1 : index + // CHECK: [[ARG0_DIM1:%.+]] = dim %arg0, [[IDX1]] + // CHECK: [[ARG1_AXIS:%.+]] = dim %arg0, 
[[AXIS]] + // CHECK: [[RESULT_AXIS:%.+]] = addi [[ARG0_DIM1]], [[ARG1_AXIS]] + // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 2] + // CHECK: [[ARG0_DIM1:%.+]] = dim %arg0, [[AXIS]] + // CHECK: [[INSERT0:%.+]] = subtensor_insert %arg0 into [[INIT]]{{\[}}[[OFFSET]], [[OFFSET]]] {{\[}}[[ARG0_DIM0]], [[ARG0_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]] + // CHECK: [[NEW_OFFSET:%.+]] = addi [[OFFSET]], [[ARG0_DIM1]] + // CHECK: [[ARG1_DIM1:%.+]] = dim %arg0, [[AXIS]] + // CHECK: [[INSERT1:%.+]] = subtensor_insert %arg0 into [[INSERT0]]{{\[}}[[OFFSET]], [[NEW_OFFSET]]] {{\[}}[[ARG0_DIM0]], [[ARG1_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]] + %1 = "tosa.concat"(%arg0, %arg0) { axis = 1 : i64} : (tensor<5x1xf32>, tensor<5x1xf32>) -> (tensor<5x2xf32>) + return +}