diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
@@ -20,6 +20,113 @@
 
 namespace {
 
+// Maps a row-major linear index into the input tensor of a transpose to the
+// row-major linear index of the same element in the output tensor. Output
+// dimension `i` reads input dimension `permValues[i]`.
+uint64_t permuteLinearIndex(uint64_t srcLinearIndex,
+                            llvm::ArrayRef<uint64_t> permValues,
+                            llvm::ArrayRef<int64_t> inputShape,
+                            llvm::ArrayRef<int64_t> outputShape) {
+  uint64_t totalCount = srcLinearIndex;
+
+  // Convert the source linear index to its corresponding multi-dimensional
+  // index into the source tensor.
+  SmallVector<uint64_t> srcIndices(inputShape.size(), 0);
+  for (int64_t dim = static_cast<int64_t>(inputShape.size()) - 1; dim >= 0;
+       --dim) {
+    srcIndices[dim] = totalCount % inputShape[dim];
+    totalCount /= inputShape[dim];
+  }
+
+  // Permute the source indices into the destination indices.
+  SmallVector<uint64_t> dstIndices(outputShape.size(), 0);
+  for (const auto &it : llvm::enumerate(permValues))
+    dstIndices[it.index()] = srcIndices[it.value()];
+
+  // Flatten the destination indices to a linear index. Accumulating from 0
+  // also handles rank-0 tensors, whose index vectors are empty.
+  uint64_t dstLinearIndex = 0;
+  for (uint64_t dim = 0; dim < outputShape.size(); ++dim)
+    dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
+
+  return dstLinearIndex;
+}
+
+// Transposes the elements of `attr`, read as `BaseType`, from `inputType` to
+// `outputType`; output dimension `i` is input dimension `permValues[i]`.
+template <typename BaseType>
+DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType,
+                                ShapedType outputType,
+                                llvm::ArrayRef<uint64_t> permValues) {
+  auto attrValues = attr.getValues<BaseType>();
+  auto inputShape = inputType.getShape();
+  auto outputShape = outputType.getShape();
+
+  SmallVector<BaseType> outputValues;
+  outputValues.resize(inputType.getNumElements());
+
+  for (const auto &it : llvm::enumerate(attrValues)) {
+    auto dstLinearIndex =
+        permuteLinearIndex(it.index(), permValues, inputShape, outputShape);
+    outputValues[dstLinearIndex] = it.value();
+  }
+
+  return DenseElementsAttr::get(outputType,
+                                llvm::ArrayRef<BaseType>(outputValues));
+}
+
+// Transposes an integer ElementsAttr using the native C++ type that matches
+// its width and signedness. Both getValues<T>() and DenseElementsAttr::get
+// with an ArrayRef<T> assert that the signedness of T is compatible with the
+// element type: signless and signed integers may use signed C++ types, but
+// unsigned integers (e.g. ui8) must use unsigned ones.
+template <typename SignedType, typename UnsignedType>
+DenseElementsAttr transposeIntType(IntegerType intType, ElementsAttr attr,
+                                   ShapedType inputType, ShapedType outputType,
+                                   llvm::ArrayRef<uint64_t> permValues) {
+  if (intType.isUnsigned())
+    return transposeType<UnsignedType>(attr, inputType, outputType, permValues);
+  return transposeType<SignedType>(attr, inputType, outputType, permValues);
+}
+
+// A type specialized transposition of an ElementsAttr.
+// This implementation tries to operate on the underlying data in its raw
+// representation when possible to avoid allocating a large number of Attribute
+// objects. Element types without a native fast path fall back to Attribute.
+DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
+                            ShapedType outputType,
+                            llvm::ArrayRef<uint64_t> permValues) {
+  auto baseType = inputType.getElementType();
+
+  if (auto intType = baseType.dyn_cast<IntegerType>()) {
+    switch (intType.getWidth()) {
+    case 8:
+      return transposeIntType<int8_t, uint8_t>(intType, attr, inputType,
+                                               outputType, permValues);
+    case 16:
+      return transposeIntType<int16_t, uint16_t>(intType, attr, inputType,
+                                                 outputType, permValues);
+    case 32:
+      return transposeIntType<int32_t, uint32_t>(intType, attr, inputType,
+                                                 outputType, permValues);
+    case 64:
+      return transposeIntType<int64_t, uint64_t>(intType, attr, inputType,
+                                                 outputType, permValues);
+    }
+  }
+
+  if (auto floatType = baseType.dyn_cast<FloatType>()) {
+    switch (floatType.getWidth()) {
+    case 32:
+      return transposeType<float>(attr, inputType, outputType, permValues);
+    case 64:
+      return transposeType<double>(attr, inputType, outputType, permValues);
+    }
+  }
+
+  return transposeType<Attribute>(attr, inputType, outputType, permValues);
+}
+
 struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -46,38 +153,10 @@
         [](const APInt &val) { return val.getZExtValue(); }));
 
     auto inputType = op.getInput1().getType().cast<ShapedType>();
-    ArrayRef<int64_t> inputShape = inputType.getShape();
-    int64_t numElements = inputType.getNumElements();
-
-    SmallVector<Attribute, 4> outputValues;
-    outputValues.resize(numElements);
-
-    // Transpose the input constant. Because we don't know its rank in advance,
-    // we need to loop over the range [0, element count) and delinearize the
-    // index.
-    auto attrValues = inputValues.getValues<Attribute>();
-    ArrayRef<int64_t> outputShape = outputType.getShape();
-    for (const auto &it : llvm::enumerate(attrValues)) {
-      SmallVector<uint64_t, 6> srcIndices(inputType.getRank(), 0);
-      int totalCount = it.index();
-      for (int dim = inputType.getRank() - 1; dim >= 0; --dim) {
-        srcIndices[dim] = totalCount % inputShape[dim];
-        totalCount /= inputShape[dim];
-      }
-
-      SmallVector<uint64_t, 6> dstIndices(outputType.getRank(), 0);
-      for (int dim = outputType.getRank() - 1; dim >= 0; --dim)
-        dstIndices[dim] = srcIndices[permValues[dim]];
-
-      uint64_t dstLinearIndex = dstIndices.front();
-      for (int dim = 1; dim < outputType.getRank(); ++dim)
-        dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
-
-      outputValues[dstLinearIndex] = it.value();
-    }
 
-    rewriter.replaceOpWithNewOp<tosa::ConstOp>(
-        op, outputType, DenseElementsAttr::get(outputType, outputValues));
+    // Transpose the constant; common element types take a raw-data fast path.
+    auto resultAttr = transpose(inputValues, inputType, outputType, permValues);
+    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);
     return success();
   }
 };