diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp --- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp +++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp @@ -12,7 +12,9 @@ #include "mlir/Conversion/TosaToTensor/TosaToTensor.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" @@ -355,56 +357,56 @@ LogicalResult matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto inputType = cast(op.getOperand(0).getType()); auto resultType = dyn_cast(op.getType()); Location loc = op.getLoc(); int axis = op.getAxis(); Value axisValue = rewriter.createOrFold( loc, rewriter.getIndexAttr(axis)); - int rank = resultType.getRank(); - SmallVector offsets, sizes, strides; - sizes.reserve(rank); - strides.resize(rank, rewriter.create(loc, 1)); - offsets.resize(rank, rewriter.create(loc, 0)); + int64_t rank = resultType.getRank(); - SmallVector dynDims; - for (int i = 0; i < rank; ++i) { - sizes.push_back(rewriter.createOrFold( - loc, adaptor.getOperands()[0], i)); - if (inputType.isDynamicDim(i)) { - dynDims.push_back( - rewriter.create(loc, op.getOperand(0), i)); - } - } + SmallVector strides(rank, rewriter.getIndexAttr(1)); + SmallVector offsets(rank, rewriter.getIndexAttr(0)); + SmallVector sizes = tensor::createDimValues( + rewriter, op.getLoc(), adaptor.getOperands()[0]); + + // Pre-compute the offset of each operand along the axis dimension. + // axisOffsets holds one entry per operand plus one: entry i is the offset + // of operand i; the last entry is the total size along 'axis'. 
+ SmallVector axisOffsets; + axisOffsets.push_back(rewriter.getIndexAttr(0)); + axisOffsets.push_back(sizes[axis]); - Value resultDimSize = sizes[axis]; for (auto arg : adaptor.getOperands().drop_front()) { auto size = rewriter.createOrFold(loc, arg, axisValue); - resultDimSize = - rewriter.createOrFold(loc, resultDimSize, size); + auto currentOffset = + getValueOrCreateConstantIndexOp(rewriter, loc, axisOffsets.back()); + auto total = + rewriter.createOrFold(loc, currentOffset, size); + axisOffsets.push_back(getAsOpFoldResult(total)); + } + sizes[axis] = axisOffsets.back(); + + // Compute the dynamic sizes of the tensor.empty operation. + // This is based on the specified result type of the tosa.concat + // operation, since we don't want to change the result type of the operation + // during the conversion. NOTE(review): resultType is never null-checked. + SmallVector dynDims; + for (int64_t i = 0; i < rank; ++i) { + if (resultType.isDynamicDim(i)) { + dynDims.push_back( + getValueOrCreateConstantIndexOp(rewriter, loc, sizes[i])); + } + } - sizes[axis] = resultDimSize; - Value emptyTensor = rewriter.create( + Value result = rewriter.create( loc, resultType.getShape(), resultType.getElementType(), dynDims); - auto toOpFoldResult = [](Value v) -> OpFoldResult { - auto op = v.getDefiningOp(); - if (!op) - return v; - return op.getValue(); - }; - Value result = emptyTensor; - for (auto arg : adaptor.getOperands()) { - sizes[axis] = rewriter.createOrFold(loc, arg, axisValue); + for (auto [arg, offset] : llvm::zip(adaptor.getOperands(), axisOffsets)) { + auto sizes = tensor::createDimValues(rewriter, op.getLoc(), arg); + offsets[axis] = offset; result = rewriter.createOrFold( - loc, arg, result, - llvm::to_vector(llvm::map_range(offsets, toOpFoldResult)), - llvm::to_vector(llvm::map_range(sizes, toOpFoldResult)), - llvm::to_vector(llvm::map_range(strides, toOpFoldResult))); - offsets[axis] = - rewriter.createOrFold(loc, offsets[axis], sizes[axis]); + loc, arg, result, offsets, sizes, strides); } 
rewriter.replaceOp(op, result); return success(); diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir --- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir +++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir @@ -202,23 +202,13 @@ // CHECK-SAME: %[[ARG0:.+]]: tensor<5x1xf32> // CHECK-SAME: %[[ARG1:.+]]: tensor<6x1xf32> func.func @concat(%arg0: tensor<5x1xf32>, %arg1: tensor<6x1xf32>) -> () { - // CHECK: [[AXIS:%.+]] = arith.constant 0 - // CHECK: [[STRIDE:%.+]] = arith.constant 1 - // CHECK: [[OFFSET:%.+]] = arith.constant 0 : index - // CHECK: [[IDX0:%.+]] = arith.constant 0 : index - // CHECK: [[IDX1:%.+]] = arith.constant 1 : index - // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<11x1xf32> - // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %[[ARG0]] into [[INIT]][0, 0] [5, 1] [1, 1] - // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %[[ARG1]] into [[INSERT0]][5, 0] [6, 1] [1, 1] + // CHECK-DAG: [[INIT:%.+]] = tensor.empty() : tensor<11x1xf32> + // CHECK-DAG: [[INSERT0:%.+]] = tensor.insert_slice %[[ARG0]] into [[INIT]][0, 0] [5, 1] [1, 1] + // CHECK-DAG: [[INSERT1:%.+]] = tensor.insert_slice %[[ARG1]] into [[INSERT0]][5, 0] [6, 1] [1, 1] %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x1xf32>, tensor<6x1xf32>) -> (tensor<11x1xf32>) - // CHECK: [[AXIS:%.+]] = arith.constant 1 - // CHECK: [[STRIDE:%.+]] = arith.constant 1 - // CHECK: [[OFFSET:%.+]] = arith.constant 0 : index - // CHECK: [[IDX0:%.+]] = arith.constant 0 : index - // CHECK: [[IDX1:%.+]] = arith.constant 1 : index - // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<5x2xf32> - // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %[[ARG0]] into [[INIT]][0, 0] [5, 1] [1, 1] + // CHECK-DAG: [[INIT:%.+]] = tensor.empty() : tensor<5x2xf32> + // CHECK-DAG: [[INSERT0:%.+]] = tensor.insert_slice %[[ARG0]] into [[INIT]][0, 0] [5, 1] [1, 1] // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %[[ARG0]] into [[INSERT0]][0, 1] 
[5, 1] [1, 1] %1 = "tosa.concat"(%arg0, %arg0) { axis = 1 : i64} : (tensor<5x1xf32>, tensor<5x1xf32>) -> (tensor<5x2xf32>) return @@ -230,17 +220,16 @@ // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]: // CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]] func.func @concat_non_axis_dyn(%arg0: tensor<5x?xf32>, %arg1: tensor<6x?xf32>) -> () { - // CHECK: %[[AXIS:.+]] = arith.constant 0 - // CHECK: %[[STRIDE:.+]] = arith.constant 1 - // CHECK: %[[OFFSET:.+]] = arith.constant 0 : index - // CHECK: %[[IDX0:.+]] = arith.constant 0 : index - // CHECK: %[[IDX1:.+]] = arith.constant 1 : index - // CHECK: %[[SIZE:.+]] = tensor.dim %[[ARG0]], %[[IDX1]] - // CHECK: %[[IDX1_2:.+]] = arith.constant 1 : index - // CHECK: %[[DYN:.+]] = tensor.dim %[[ARG0]], %[[IDX1_2]] - // CHECK: %[[INIT:.+]] = tensor.empty(%[[DYN]]) : tensor<11x?xf32> - // CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0] [5, %[[SIZE]]] [1, 1] - // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][5, 0] [6, %[[SIZE]]] [1, 1] + // CHECK-DAG: %[[AXIS:.+]] = arith.constant 0 + // CHECK-DAG: %[[IDX1:.+]] = arith.constant 1 + // CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[IDX1]] + // CHECK-DAG: %[[INIT:.+]] = tensor.empty(%[[DIM0]]) : tensor<11x?xf32> + // CHECK-DAG: %[[IDX1_1:.+]] = arith.constant 1 : index + // CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[IDX1_1]] + // CHECK-DAG: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0] [5, %[[DIM1]]] [1, 1] + // CHECK-DAG: %[[IDX1_2:.+]] = arith.constant 1 : index + // CHECK-DAG: %[[DIM2:.+]] = tensor.dim %[[ARG1]], %[[IDX1_2]] : tensor<6x?xf32> + // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][5, 0] [6, %[[DIM2]]] [1, 1] %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x?xf32>, tensor<6x?xf32>) -> (tensor<11x?xf32>) return } @@ -251,20 +240,76 @@ // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]: // CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]: func.func @concat_axis_dyn(%arg0: tensor, %arg1: 
tensor) -> () { - // CHECK: %[[AXIS:.+]] = arith.constant 0 - // CHECK: %[[STRIDE:.+]] = arith.constant 1 - // CHECK: %[[OFFSET:.+]] = arith.constant 0 : index - // CHECK: %[[IDX0:.+]] = arith.constant 0 : index - // CHECK: %[[SIZE:.+]] = tensor.dim %[[ARG0]], %[[IDX0]] - // CHECK: %[[IDX0_2:.+]] = arith.constant 0 : index - // CHECK: %[[DYN:.+]] = tensor.dim %[[ARG0]], %[[IDX0_2]] - // CHECK: %[[IDX1:.+]] = arith.constant 1 : index - // CHECK: %[[INIT:.+]] = tensor.empty(%[[DYN]]) : tensor - // CHECK: %[[DYN1:.+]] = tensor.dim %[[ARG0]], %[[AXIS]] - // CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0] [%[[DYN1]], 3] [1, 1] - // CHECK: %[[SUM:.+]] = arith.addi %[[OFFSET]], %[[DYN1]] - // CHECK: %[[DYN2:.+]] = tensor.dim %[[ARG1]], %[[AXIS]] - // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][%[[SUM]], 0] [%[[DYN2]], 3] [1, 1] + // CHECK-DAG: %[[AXIS:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[IDX0:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[IDX0]] : tensor + // CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG1]], %[[AXIS]] : tensor + // CHECK-DAG: %[[SUM:.+]] = arith.addi %[[DIM0]], %[[DIM1]] : index + // CHECK-DAG: %[[INIT:.+]] = tensor.empty(%[[SUM]]) : tensor + // CHECK-DAG: %[[IDX0_1:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[IDX0_1]] : tensor + // CHECK-DAG: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0] [%[[DIM2]], 3] [1, 1] : tensor into tensor + // CHECK-DAG: %[[IDX0_2:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[DIM3:.+]] = tensor.dim %[[ARG1]], %[[IDX0_2]] : tensor + // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][%[[DIM0]], 0] [%[[DIM3]], 3] [1, 1] : tensor into tensor + %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor, tensor) -> (tensor) return } + +// ----- + +// CHECK-LABEL: @concat_axis_dyn_mixed +// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]: +// 
CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]: +// CHECK-SAME: %[[ARG2:[0-9a-zA-Z_]*]]: +func.func @concat_axis_dyn_mixed(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> () { + // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[C0_0:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[OFFSET0:.+]] = tensor.dim %[[ARG0]], %[[C0_0]] : tensor + // CHECK-DAG: %[[DIM1_0:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor + // CHECK-DAG: %[[OFFSET1:.+]] = arith.addi %[[OFFSET0]], %[[DIM1_0]] : index + // CHECK-DAG: %[[DIM2_2:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor + // CHECK-DAG: %[[OFFSET2:.+]] = arith.addi %[[OFFSET1]], %[[DIM2_2]] : index + // CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<5x1xf32> + // CHECK-DAG: %[[C0_3:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[DIM_4:.+]] = tensor.dim %[[ARG0]], %[[C0_3]] : tensor + // CHECK-DAG: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0] [%[[DIM_4]], 1] [1, 1] : tensor into tensor<5x1xf32> + // CHECK-DAG: %[[C0_4:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[DIM_6:.+]] = tensor.dim %[[ARG1]], %[[C0_4]] : tensor + // CHECK-DAG: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][%[[OFFSET0]], 0] [%[[DIM_6]], 1] [1, 1] : tensor into tensor<5x1xf32> + // CHECK-DAG: %[[C0_8:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[DIM_9:.+]] = tensor.dim %[[ARG2]], %[[C0_8]] : tensor + // CHECK-DAG: %[[INSERT3:.+]] = tensor.insert_slice %[[ARG2]] into %[[INSERT1]][%[[OFFSET1]], 0] [%[[DIM_9]], 1] [1, 1] : tensor into tensor<5x1xf32> + + // CHECK: return + + %0 = "tosa.concat"(%arg0, %arg1, %arg2) <{axis = 0 : i64}> : (tensor, tensor, tensor) -> tensor<5x1xf32> + return +} + +// ----- + +// CHECK-LABEL: @concat_non_axis_dyn_mixed +// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]: +// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]: +// CHECK-SAME: %[[ARG2:[0-9a-zA-Z_]*]]: +func.func @concat_non_axis_dyn_mixed(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> () { + // CHECK-DAG: %[[UNUSED0:.+]] = 
arith.constant 0 : index + // CHECK-DAG: %[[UNUSED1:.+]] = tensor.dim %[[ARG0]], %[[UNUSED0]] : tensor + + // CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<5x3xf32> + // CHECK-DAG: %[[C0_0:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[DIM0_0:.+]] = tensor.dim %[[ARG0]], %[[C0_0]] : tensor + // CHECK-DAG: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0] [%[[DIM0_0]], 1] [1, 1] : tensor into tensor<5x3xf32> + // CHECK-DAG: %[[C0_1:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[DIM1_0:.+]] = tensor.dim %[[ARG1]], %[[C0_1]] : tensor + // CHECK-DAG: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][0, 1] [%[[DIM1_0]], 1] [1, 1] : tensor into tensor<5x3xf32> + // CHECK-DAG: %[[C0_2:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[DIM2_0:.+]] = tensor.dim %[[ARG2]], %[[C0_2]] : tensor + // CHECK-DAG: %[[INSERT2:.+]] = tensor.insert_slice %[[ARG2]] into %[[INSERT1]][0, 2] [%[[DIM2_0]], 1] [1, 1] : tensor into tensor<5x3xf32> + // CHECK: return + + %0 = "tosa.concat"(%arg0, %arg1, %arg2) <{axis = 1 : i64}> : (tensor, tensor, tensor) -> tensor<5x3xf32> + return +}