diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1556,68 +1556,6 @@
   }
 };
 
-struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
-  using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto inputType = op.getOperand(0).getType().template cast<ShapedType>();
-    auto resultType = op.getType().dyn_cast<RankedTensorType>();
-
-    Location loc = op.getLoc();
-    int axis = op.getAxis();
-    Value axisValue = rewriter.createOrFold<arith::ConstantOp>(
-        loc, rewriter.getIndexAttr(axis));
-    int rank = resultType.getRank();
-    SmallVector<Value> offsets, sizes, strides;
-    sizes.reserve(rank);
-    strides.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 1));
-    offsets.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 0));
-
-    SmallVector<Value> dynDims;
-    for (int i = 0; i < rank; ++i) {
-      sizes.push_back(rewriter.createOrFold<tensor::DimOp>(
-          loc, adaptor.getOperands()[0], i));
-      if (inputType.isDynamicDim(i)) {
-        dynDims.push_back(
-            rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
-      }
-    }
-
-    Value resultDimSize = sizes[axis];
-    for (auto arg : adaptor.getOperands().drop_front()) {
-      auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
-      resultDimSize =
-          rewriter.createOrFold<arith::AddIOp>(loc, resultDimSize, size);
-    }
-    sizes[axis] = resultDimSize;
-
-    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
-        loc, resultType.getShape(), resultType.getElementType(), dynDims);
-
-    auto toOpFoldResult = [](Value v) -> OpFoldResult {
-      auto op = v.getDefiningOp<arith::ConstantIndexOp>();
-      if (!op)
-        return v;
-      return op.getValue();
-    };
-    Value result = emptyTensor;
-    for (auto arg : adaptor.getOperands()) {
-      sizes[axis] = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
-      result = rewriter.createOrFold<tensor::InsertSliceOp>(
-          loc, arg, result,
-          llvm::to_vector(llvm::map_range(offsets, toOpFoldResult)),
-          llvm::to_vector(llvm::map_range(sizes, toOpFoldResult)),
-          llvm::to_vector(llvm::map_range(strides, toOpFoldResult)));
-      offsets[axis] =
-          rewriter.createOrFold<arith::AddIOp>(loc, offsets[axis], sizes[axis]);
-    }
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-};
-
 class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
 public:
   using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;
@@ -2110,7 +2048,6 @@
       ReduceConverter<tosa::ReduceSumOp>,
       ReduceConverter<tosa::ReduceProdOp>,
       ArgMaxConverter,
-      ConcatConverter,
      GatherConverter,
       RescaleConverter,
       ReverseConverter,
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -349,11 +349,74 @@
   }
 };
 
+struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
+  using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto inputType = op.getOperand(0).getType().template cast<ShapedType>();
+    auto resultType = op.getType().dyn_cast<RankedTensorType>();
+
+    Location loc = op.getLoc();
+    int axis = op.getAxis();
+    Value axisValue = rewriter.createOrFold<arith::ConstantOp>(
+        loc, rewriter.getIndexAttr(axis));
+    int rank = resultType.getRank();
+    SmallVector<Value> offsets, sizes, strides;
+    sizes.reserve(rank);
+    strides.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 1));
+    offsets.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 0));
+
+    SmallVector<Value> dynDims;
+    for (int i = 0; i < rank; ++i) {
+      sizes.push_back(rewriter.createOrFold<tensor::DimOp>(
+          loc, adaptor.getOperands()[0], i));
+      if (inputType.isDynamicDim(i)) {
+        dynDims.push_back(
+            rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
+      }
+    }
+
+    Value resultDimSize = sizes[axis];
+    for (auto arg : adaptor.getOperands().drop_front()) {
+      auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
+      resultDimSize =
+          rewriter.createOrFold<arith::AddIOp>(loc, resultDimSize, size);
+    }
+    sizes[axis] = resultDimSize;
+
+    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
+        loc, resultType.getShape(), resultType.getElementType(), dynDims);
+
+    auto toOpFoldResult = [](Value v) -> OpFoldResult {
+      auto op = v.getDefiningOp<arith::ConstantIndexOp>();
+      if (!op)
+        return v;
+      return op.getValue();
+    };
+    Value result = emptyTensor;
+    for (auto arg : adaptor.getOperands()) {
+      sizes[axis] = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
+      result = rewriter.createOrFold<tensor::InsertSliceOp>(
+          loc, arg, result,
+          llvm::to_vector(llvm::map_range(offsets, toOpFoldResult)),
+          llvm::to_vector(llvm::map_range(sizes, toOpFoldResult)),
+          llvm::to_vector(llvm::map_range(strides, toOpFoldResult)));
+      offsets[axis] =
+          rewriter.createOrFold<arith::AddIOp>(loc, offsets[axis], sizes[axis]);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tosa::populateTosaToTensorConversionPatterns(
     RewritePatternSet *patterns) {
-  patterns->add<SliceConverter, PadConverter>(patterns->getContext());
+  patterns->add<SliceConverter, PadConverter, ConcatConverter>(
+      patterns->getContext());
   patterns->add<ReshapeConverterCollapse>(patterns->getContext(),
                                           /*benefit=*/100);
   patterns->add<ReshapeConverterExpand>(patterns->getContext(),
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -823,79 +823,6 @@
   }
   return
 }
-// -----
-
-// CHECK-LABEL: @concat
-// CHECK-SAME: %[[ARG0:.+]]: tensor<5x1xf32>
-// CHECK-SAME: %[[ARG1:.+]]: tensor<6x1xf32>
-func.func @concat(%arg0: tensor<5x1xf32>, %arg1: tensor<6x1xf32>) -> () {
-  // CHECK: [[AXIS:%.+]] = arith.constant 0
-  // CHECK: [[STRIDE:%.+]] = arith.constant 1
-  // CHECK: [[OFFSET:%.+]] = arith.constant 0 : index
-  // CHECK: [[IDX0:%.+]] = arith.constant 0 : index
-  // CHECK: [[IDX1:%.+]] = arith.constant 1 : index
-  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<11x1xf32>
-  // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %[[ARG0]] into [[INIT]][0, 0] [5, 1] [1, 1]
-  // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %[[ARG1]] into [[INSERT0]][5, 0] [6, 1] [1, 1]
-  %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x1xf32>, tensor<6x1xf32>) -> (tensor<11x1xf32>)
-
-  // CHECK: [[AXIS:%.+]] = arith.constant 1
-  // CHECK: [[STRIDE:%.+]] = arith.constant 1
-  // CHECK: [[OFFSET:%.+]] = arith.constant 0 : index
-  // CHECK: [[IDX0:%.+]] = arith.constant 0 : index
-  // CHECK: [[IDX1:%.+]] = arith.constant 1 : index
-  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<5x2xf32>
-  // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %[[ARG0]] into [[INIT]][0, 0] [5, 1] [1, 1]
-  // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %[[ARG0]] into [[INSERT0]][0, 1] [5, 1] [1, 1]
-  %1 = "tosa.concat"(%arg0, %arg0) { axis = 1 : i64} : (tensor<5x1xf32>, tensor<5x1xf32>) -> (tensor<5x2xf32>)
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @concat_non_axis_dyn
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
-// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]
-func.func @concat_non_axis_dyn(%arg0: tensor<5x?xf32>, %arg1: tensor<6x?xf32>) -> () {
-  // CHECK: %[[AXIS:.+]] = arith.constant 0
-  // CHECK: %[[STRIDE:.+]] = arith.constant 1
-  // CHECK: %[[OFFSET:.+]] = arith.constant 0 : index
-  // CHECK: %[[IDX0:.+]] = arith.constant 0 : index
-  // CHECK: %[[IDX1:.+]] = arith.constant 1 : index
-  // CHECK: %[[SIZE:.+]] = tensor.dim %[[ARG0]], %[[IDX1]]
-  // CHECK: %[[IDX1_2:.+]] = arith.constant 1 : index
-  // CHECK: %[[DYN:.+]] = tensor.dim %[[ARG0]], %[[IDX1_2]]
-  // CHECK: %[[INIT:.+]] = tensor.empty(%[[DYN]]) : tensor<11x?xf32>
-  // CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0] [5, %[[SIZE]]] [1, 1]
-  // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][5, 0] [6, %[[SIZE]]] [1, 1]
-  %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x?xf32>, tensor<6x?xf32>) -> (tensor<11x?xf32>)
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @concat_axis_dyn
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
-// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
-func.func @concat_axis_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<?x3xf32>) -> () {
-  // CHECK: %[[AXIS:.+]] = arith.constant 0
-  // CHECK: %[[STRIDE:.+]] = arith.constant 1
-  // CHECK: %[[OFFSET:.+]] = arith.constant 0 : index
-  // CHECK: %[[IDX0:.+]] = arith.constant 0 : index
-  // CHECK: %[[SIZE:.+]] = tensor.dim %[[ARG0]], %[[IDX0]]
-  // CHECK: %[[IDX0_2:.+]] = arith.constant 0 : index
-  // CHECK: %[[DYN:.+]] = tensor.dim %[[ARG0]], %[[IDX0_2]]
-  // CHECK: %[[IDX1:.+]] = arith.constant 1 : index
-  // CHECK: %[[INIT:.+]] = tensor.empty(%[[DYN]]) : tensor<?x3xf32>
-  // CHECK: %[[DYN1:.+]] = tensor.dim %[[ARG0]], %[[AXIS]]
-  // CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0] [%[[DYN1]], 3] [1, 1]
-  // CHECK: %[[SUM:.+]] = arith.addi %[[OFFSET]], %[[DYN1]]
-  // CHECK: %[[DYN2:.+]] = tensor.dim %[[ARG1]], %[[AXIS]]
-  // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][%[[SUM]], 0] [%[[DYN2]], 3] [1, 1]
-  %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<?x3xf32>, tensor<?x3xf32>) -> (tensor<?x3xf32>)
-  return
-}
-
 // -----
 
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -195,3 +195,76 @@
   %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<?x9xf32>)
   return %1 : tensor<?x9xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @concat
+// CHECK-SAME: %[[ARG0:.+]]: tensor<5x1xf32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<6x1xf32>
+func.func @concat(%arg0: tensor<5x1xf32>, %arg1: tensor<6x1xf32>) -> () {
+  // CHECK: [[AXIS:%.+]] = arith.constant 0
+  // CHECK: [[STRIDE:%.+]] = arith.constant 1
+  // CHECK: [[OFFSET:%.+]] = arith.constant 0 : index
+  // CHECK: [[IDX0:%.+]] = arith.constant 0 : index
+  // CHECK: [[IDX1:%.+]] = arith.constant 1 : index
+  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<11x1xf32>
+  // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %[[ARG0]] into [[INIT]][0, 0] [5, 1] [1, 1]
+  // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %[[ARG1]] into [[INSERT0]][5, 0] [6, 1] [1, 1]
+  %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x1xf32>, tensor<6x1xf32>) -> (tensor<11x1xf32>)
+
+  // CHECK: [[AXIS:%.+]] = arith.constant 1
+  // CHECK: [[STRIDE:%.+]] = arith.constant 1
+  // CHECK: [[OFFSET:%.+]] = arith.constant 0 : index
+  // CHECK: [[IDX0:%.+]] = arith.constant 0 : index
+  // CHECK: [[IDX1:%.+]] = arith.constant 1 : index
+  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<5x2xf32>
+  // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %[[ARG0]] into [[INIT]][0, 0] [5, 1] [1, 1]
+  // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %[[ARG0]] into [[INSERT0]][0, 1] [5, 1] [1, 1]
+  %1 = "tosa.concat"(%arg0, %arg0) { axis = 1 : i64} : (tensor<5x1xf32>, tensor<5x1xf32>) -> (tensor<5x2xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @concat_non_axis_dyn
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]
+func.func @concat_non_axis_dyn(%arg0: tensor<5x?xf32>, %arg1: tensor<6x?xf32>) -> () {
+  // CHECK: %[[AXIS:.+]] = arith.constant 0
+  // CHECK: %[[STRIDE:.+]] = arith.constant 1
+  // CHECK: %[[OFFSET:.+]] = arith.constant 0 : index
+  // CHECK: %[[IDX0:.+]] = arith.constant 0 : index
+  // CHECK: %[[IDX1:.+]] = arith.constant 1 : index
+  // CHECK: %[[SIZE:.+]] = tensor.dim %[[ARG0]], %[[IDX1]]
+  // CHECK: %[[IDX1_2:.+]] = arith.constant 1 : index
+  // CHECK: %[[DYN:.+]] = tensor.dim %[[ARG0]], %[[IDX1_2]]
+  // CHECK: %[[INIT:.+]] = tensor.empty(%[[DYN]]) : tensor<11x?xf32>
+  // CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0] [5, %[[SIZE]]] [1, 1]
+  // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][5, 0] [6, %[[SIZE]]] [1, 1]
+  %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x?xf32>, tensor<6x?xf32>) -> (tensor<11x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @concat_axis_dyn
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
+func.func @concat_axis_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<?x3xf32>) -> () {
+  // CHECK: %[[AXIS:.+]] = arith.constant 0
+  // CHECK: %[[STRIDE:.+]] = arith.constant 1
+  // CHECK: %[[OFFSET:.+]] = arith.constant 0 : index
+  // CHECK: %[[IDX0:.+]] = arith.constant 0 : index
+  // CHECK: %[[SIZE:.+]] = tensor.dim %[[ARG0]], %[[IDX0]]
+  // CHECK: %[[IDX0_2:.+]] = arith.constant 0 : index
+  // CHECK: %[[DYN:.+]] = tensor.dim %[[ARG0]], %[[IDX0_2]]
+  // CHECK: %[[IDX1:.+]] = arith.constant 1 : index
+  // CHECK: %[[INIT:.+]] = tensor.empty(%[[DYN]]) : tensor<?x3xf32>
+  // CHECK: %[[DYN1:.+]] = tensor.dim %[[ARG0]], %[[AXIS]]
+  // CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0] [%[[DYN1]], 3] [1, 1]
+  // CHECK: %[[SUM:.+]] = arith.addi %[[OFFSET]], %[[DYN1]]
+  // CHECK: %[[DYN2:.+]] = tensor.dim %[[ARG1]], %[[AXIS]]
+  // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][%[[SUM]], 0] [%[[DYN2]], 3] [1, 1]
+  %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<?x3xf32>, tensor<?x3xf32>) -> (tensor<?x3xf32>)
+  return
+}
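
Note (not part of the patch): the relocated ConcatConverter still lowers tosa.concat to a tensor.empty of the concatenated shape followed by a chain of tensor.insert_slice ops, as the tests above check; only its registration moves from the TosaToLinalg to the TosaToTensor pattern set. Below is a minimal sketch of how a downstream pipeline might now pick the pattern up. It assumes the declaration of populateTosaToTensorConversionPatterns lives in mlir/Conversion/TosaToTensor/TosaToTensor.h (matching the .cpp above); the driver function lowerTosaConcat itself is hypothetical, not part of MLIR.

// Hypothetical helper, for illustration only: runs the TosaToTensor patterns
// (which now include ConcatConverter) over a single function.
#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Transforms/DialectConversion.h"

#include <utility>

using namespace mlir;

static LogicalResult lowerTosaConcat(func::FuncOp func) {
  MLIRContext *ctx = func.getContext();

  // tosa.concat must be rewritten; the ops it is lowered into stay legal.
  ConversionTarget target(*ctx);
  target.addIllegalOp<tosa::ConcatOp>();
  target.addLegalDialect<arith::ArithDialect, tensor::TensorDialect>();

  // Registers ConcatConverter alongside the other TosaToTensor patterns.
  RewritePatternSet patterns(ctx);
  tosa::populateTosaToTensorConversionPatterns(&patterns);

  return applyPartialConversion(func, target, std::move(patterns));
}

Because ConcatConverter is an OpConversionPattern, it is driven through the dialect-conversion framework rather than the greedy rewriter.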