diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Pass/Pass.h"
 
@@ -20,6 +21,82 @@
 namespace {
 
+// Transpose the elements of `attr` (interpreted as values of C++ type
+// BaseType) according to `permValues`, producing a DenseElementsAttr of
+// `outputType`. Works on raw values rather than Attribute objects to avoid
+// allocating one Attribute per element.
+template <typename BaseType>
+DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType,
+                                ShapedType outputType,
+                                llvm::ArrayRef<int64_t> permValues) {
+  if (inputType.getNumElements() == 0)
+    return DenseElementsAttr::get(outputType, llvm::ArrayRef<BaseType>{});
+
+  auto attrValues = attr.getValues<BaseType>();
+  auto inputShape = inputType.getShape();
+
+  // The inverted permutation map and strides of the output are used to compute
+  // the contribution of a given dimension to the destination linear index in
+  // an order-independent way.
+  auto outputStrides = computeStrides(outputType.getShape());
+  auto invertedPermValues = invertPermutationVector(permValues);
+
+  auto initialValue = *std::begin(attrValues);
+  SmallVector<BaseType> outputValues(inputType.getNumElements(), initialValue);
+
+  for (const auto &it : llvm::enumerate(attrValues)) {
+    auto srcLinearIndex = it.index();
+
+    uint64_t dstLinearIndex = 0;
+    for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
+      // Compute the index into the current dimension of the source vector.
+      auto sourceIndexForDim = srcLinearIndex % inputShape[dim];
+      srcLinearIndex /= inputShape[dim];
+
+      // Add the contribution of the current dimension to the output using the
+      // permutation map.
+      dstLinearIndex +=
+          outputStrides[invertedPermValues[dim]] * sourceIndexForDim;
+    }
+
+    outputValues[dstLinearIndex] = it.value();
+  }
+
+  return DenseElementsAttr::get(outputType,
+                                llvm::ArrayRef<BaseType>(outputValues));
+}
+
+// A type specialized transposition of an ElementsAttr.
+// This implementation tries to operate on the underlying data in its raw
+// representation when possible to avoid allocating a large number of Attribute
+// objects.
+DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
+                            ShapedType outputType,
+                            llvm::ArrayRef<int64_t> permValues) {
+  auto baseType = inputType.getElementType();
+
+  // Handle possible integer types
+  if (auto intType = baseType.dyn_cast<IntegerType>()) {
+    switch (intType.getWidth()) {
+    case 1:
+      return transposeType<bool>(attr, inputType, outputType, permValues);
+    case 8:
+      return transposeType<int8_t>(attr, inputType, outputType, permValues);
+    case 16:
+      return transposeType<int16_t>(attr, inputType, outputType, permValues);
+    case 32:
+      return transposeType<int32_t>(attr, inputType, outputType, permValues);
+    case 64:
+      return transposeType<int64_t>(attr, inputType, outputType, permValues);
+    default:
+      // Odd bit widths (e.g. i48) fall back to the generic APInt path.
+      return transposeType<APInt>(attr, inputType, outputType, permValues);
+    }
+  }
+
+  // Handle possible float types
+  if (baseType.isF32()) {
+    return transposeType<float>(attr, inputType, outputType, permValues);
+  }
+
+  // All other float types (f16, bf16, f64, ...) go through APFloat.
+  return transposeType<APFloat>(attr, inputType, outputType, permValues);
+}
 
 struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -43,41 +120,12 @@
     auto permValues = llvm::to_vector<6>(llvm::map_range(
         // TOSA allows both 32- and 64-bit integer tensors here.
         permAttr.getValues<APInt>(),
-        [](const APInt &val) { return val.getZExtValue(); }));
+        [](const APInt &val) { return val.getSExtValue(); }));
 
     auto inputType = op.getInput1().getType().cast<ShapedType>();
-    ArrayRef<int64_t> inputShape = inputType.getShape();
-    int64_t numElements = inputType.getNumElements();
-
-    SmallVector<Attribute, 4> outputValues;
-    outputValues.resize(numElements);
-
-    // Transpose the input constant. Because we don't know its rank in advance,
-    // we need to loop over the range [0, element count) and delinearize the
-    // index.
-    auto attrValues = inputValues.getValues<Attribute>();
-    ArrayRef<int64_t> outputShape = outputType.getShape();
-    for (const auto &it : llvm::enumerate(attrValues)) {
-      SmallVector<uint64_t, 6> srcIndices(inputType.getRank(), 0);
-      int totalCount = it.index();
-      for (int dim = inputType.getRank() - 1; dim >= 0; --dim) {
-        srcIndices[dim] = totalCount % inputShape[dim];
-        totalCount /= inputShape[dim];
-      }
-
-      SmallVector<uint64_t, 6> dstIndices(outputType.getRank(), 0);
-      for (int dim = outputType.getRank() - 1; dim >= 0; --dim)
-        dstIndices[dim] = srcIndices[permValues[dim]];
-
-      uint64_t dstLinearIndex = dstIndices.front();
-      for (int dim = 1; dim < outputType.getRank(); ++dim)
-        dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
-
-      outputValues[dstLinearIndex] = it.value();
-    }
-    rewriter.replaceOpWithNewOp<tosa::ConstOp>(
-        op, outputType, DenseElementsAttr::get(outputType, outputValues));
+    auto resultAttr = transpose(inputValues, inputType, outputType, permValues);
+    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);
 
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -46,6 +46,17 @@
   return %1 : tensor<3x2xf32>
 }
 
+// CHECK-LABEL: @transpose_fold_2d_bool
+func.func @transpose_fold_2d_bool() -> tensor<3x2xi1> {
+  %input = "tosa.const"() {value = dense<[[true, false, false], [false, false, true]]> : tensor<2x3xi1>} : () -> tensor<2x3xi1>
+  %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // CHECK: %[[CST:.+]] = "tosa.const"()
+  // CHECK-SAME{LITERAL}: value = dense<[[true, false], [false, false], [false, true]]> : tensor<3x2xi1>
+  %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xi1>, tensor<2xi32>) -> tensor<3x2xi1>
+  // CHECK: return %[[CST]]
+  return %1 : tensor<3x2xi1>
+}
+
 // CHECK-LABEL: @transpose_fold_4d_int
 func.func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> {
   %input = "tosa.const"() {value = dense<[[