diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
@@ -20,6 +20,103 @@
 
 namespace {
 
+// Maps a row-major linear index into a tensor of shape `inputShape` to the
+// row-major linear index of the same element after it has been transposed by
+// `permValues` into a tensor of shape `outputShape`, where output dimension i
+// is input dimension permValues[i].
+uint64_t permuteLinearIndex(uint64_t srcLinearIndex,
+                            llvm::ArrayRef<uint64_t> permValues,
+                            llvm::ArrayRef<int64_t> inputShape,
+                            llvm::ArrayRef<int64_t> outputShape) {
+  uint64_t totalCount = srcLinearIndex;
+
+  // Convert the source linear index to its corresponding multi-dimensional
+  // index into the source tensor.
+  SmallVector<uint64_t> srcIndices(inputShape.size(), 0);
+  for (int dim = static_cast<int>(inputShape.size()) - 1; dim >= 0; --dim) {
+    srcIndices[dim] = totalCount % inputShape[dim];
+    totalCount /= inputShape[dim];
+  }
+
+  // Permute the source indices into the destination indices.
+  SmallVector<uint64_t> dstIndices(outputShape.size(), 0);
+  for (const auto &it : llvm::enumerate(permValues))
+    dstIndices[it.index()] = srcIndices[it.value()];
+
+  // Flatten the destination indices to a linear index (0 for rank-0 tensors).
+  uint64_t dstLinearIndex = 0;
+  for (uint64_t dim = 0; dim < outputShape.size(); ++dim)
+    dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
+
+  return dstLinearIndex;
+}
+
+// Returns `attr` (of type `inputType`) transposed by `permValues` as a
+// DenseElementsAttr of type `outputType`, reading and writing the elements as
+// `BaseType` values.
+template <typename BaseType>
+DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType,
+                                ShapedType outputType,
+                                llvm::ArrayRef<uint64_t> permValues) {
+  if (inputType.getNumElements() == 0)
+    return DenseElementsAttr::get(outputType, llvm::ArrayRef<BaseType>{});
+
+  auto attrValues = attr.getValues<BaseType>();
+  auto inputShape = inputType.getShape();
+  auto outputShape = outputType.getShape();
+
+  // Every slot is overwritten below; the first element only serves as a valid
+  // initial value for types without a default constructor (e.g. APFloat).
+  auto initialValue = *std::begin(attrValues);
+  SmallVector<BaseType> outputValues(inputType.getNumElements(), initialValue);
+
+  for (const auto &it : llvm::enumerate(attrValues)) {
+    auto dstLinearIndex =
+        permuteLinearIndex(it.index(), permValues, inputShape, outputShape);
+    outputValues[dstLinearIndex] = it.value();
+  }
+
+  return DenseElementsAttr::get(outputType,
+                                llvm::ArrayRef<BaseType>(outputValues));
+}
+
+// A type specialized transposition of an ElementsAttr.
+// This implementation tries to operate on the underlying data in its raw
+// representation when possible to avoid allocating a large number of Attribute
+// objects.
+DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
+                            ShapedType outputType,
+                            llvm::ArrayRef<uint64_t> permValues) {
+  auto baseType = inputType.getElementType();
+
+  // Handle possible integer types
+  if (auto intType = baseType.dyn_cast<IntegerType>()) {
+    switch (intType.getWidth()) {
+    case 1:
+      return transposeType<bool>(attr, inputType, outputType, permValues);
+    case 8:
+      return transposeType<int8_t>(attr, inputType, outputType, permValues);
+    case 16:
+      return transposeType<int16_t>(attr, inputType, outputType, permValues);
+    case 32:
+      return transposeType<int32_t>(attr, inputType, outputType, permValues);
+    case 64:
+      return transposeType<int64_t>(attr, inputType, outputType, permValues);
+    default:
+      return transposeType<APInt>(attr, inputType, outputType, permValues);
+    }
+  }
+
+  // Handle possible float types
+  if (baseType.isF32())
+    return transposeType<float>(attr, inputType, outputType, permValues);
+  if (baseType.isa<FloatType>())
+    return transposeType<APFloat>(attr, inputType, outputType, permValues);
+
+  // Anything else (e.g. index) falls back to generic Attribute values.
+  return transposeType<Attribute>(attr, inputType, outputType, permValues);
+}
+
 struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -46,38 +143,9 @@
         [](const APInt &val) { return val.getZExtValue(); }));
 
     auto inputType = op.getInput1().getType().cast<ShapedType>();
-    ArrayRef<int64_t> inputShape = inputType.getShape();
-    int64_t numElements = inputType.getNumElements();
-
-    SmallVector<Attribute, 4> outputValues;
-    outputValues.resize(numElements);
-
-    // Transpose the input constant. Because we don't know its rank in advance,
-    // we need to loop over the range [0, element count) and delinearize the
-    // index.
-    auto attrValues = inputValues.getValues<Attribute>();
-    ArrayRef<int64_t> outputShape = outputType.getShape();
-    for (const auto &it : llvm::enumerate(attrValues)) {
-      SmallVector<uint64_t, 6> srcIndices(inputType.getRank(), 0);
-      int totalCount = it.index();
-      for (int dim = inputType.getRank() - 1; dim >= 0; --dim) {
-        srcIndices[dim] = totalCount % inputShape[dim];
-        totalCount /= inputShape[dim];
-      }
-
-      SmallVector<uint64_t, 6> dstIndices(outputType.getRank(), 0);
-      for (int dim = outputType.getRank() - 1; dim >= 0; --dim)
-        dstIndices[dim] = srcIndices[permValues[dim]];
-
-      uint64_t dstLinearIndex = dstIndices.front();
-      for (int dim = 1; dim < outputType.getRank(); ++dim)
-        dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
-
-      outputValues[dstLinearIndex] = it.value();
-    }
 
-    rewriter.replaceOpWithNewOp<tosa::ConstOp>(
-        op, outputType, DenseElementsAttr::get(outputType, outputValues));
+    auto resultAttr = transpose(inputValues, inputType, outputType, permValues);
+    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -46,6 +46,17 @@
   return %1 : tensor<3x2xf32>
 }
 
+// CHECK-LABEL: @transpose_fold_2d_bool
+func.func @transpose_fold_2d_bool() -> tensor<3x2xi1> {
+  %input = "tosa.const"() {value = dense<[[true, false, false], [false, false, true]]> : tensor<2x3xi1>} : () -> tensor<2x3xi1>
+  %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // CHECK: %[[CST:.+]] = "tosa.const"()
+  // CHECK-SAME{LITERAL}: value = dense<[[true, false], [false, false], [false, true]]> : tensor<3x2xi1>
+  %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xi1>, tensor<2xi32>) -> tensor<3x2xi1>
+  // CHECK: return %[[CST]]
+  return %1 : tensor<3x2xi1>
+}
+
 // CHECK-LABEL: @transpose_fold_4d_int
 func.func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> {
%input = "tosa.const"() {value = dense<[[