diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -438,6 +439,71 @@
 
     rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
         reshape, resultTy, args[0], reassociationMap);
+
+    return success();
+  }
+};
+
+/// Lowers tosa.transpose with a constant permutation operand to a
+/// linalg.generic that reads the input through the permuted indexing map and
+/// writes an init_tensor through the identity map. Only results with a static
+/// shape are handled.
+class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
+public:
+  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::TransposeOp op,
+                                PatternRewriter &rewriter) const final {
+    DenseIntElementsAttr perms;
+    if (!matchPattern(op.perms(), m_Constant(&perms)))
+      return failure();
+
+    auto resultTy = op.getType().cast<ShapedType>();
+    if (!resultTy.hasStaticShape())
+      return failure();
+    int64_t rank = resultTy.getRank();
+
+    // The input indexing map built below has `rank` results, so the input
+    // must be ranked and have the same rank as the result.
+    auto inputTy = op.input1().getType().dyn_cast<ShapedType>();
+    if (!inputTy || !inputTy.hasRank() || inputTy.getRank() != rank)
+      return failure();
+
+    // Result dimension i reads input dimension perms[i], so the input map
+    // sends input dimension perms[i] to loop dimension d_i. Reject anything
+    // that is not a permutation of [0, rank): an out-of-range entry would
+    // index past inputExprs, and a repeated entry would leave a null
+    // AffineExpr in the map.
+    if (perms.getNumElements() != rank)
+      return failure();
+
+    SmallVector<AffineExpr, 4> inputExprs(rank);
+    for (auto permutation : llvm::enumerate(perms.getIntValues())) {
+      // Zero-extension turns negative entries into large unsigned values,
+      // which the bounds check below rejects.
+      uint64_t inputDim = permutation.value().getZExtValue();
+      if (inputDim >= static_cast<uint64_t>(rank) || inputExprs[inputDim])
+        return failure();
+      inputExprs[inputDim] = rewriter.getAffineDimExpr(permutation.index());
+    }
+
+    auto initTensor = rewriter.create<linalg::InitTensorOp>(
+        op.getLoc(), ArrayRef<Value>({}), resultTy.getShape(),
+        resultTy.getElementType());
+
+    SmallVector<AffineMap, 2> affineMaps = {
+        AffineMap::get(rank, /*symbolCount=*/0, inputExprs,
+                       rewriter.getContext()),
+        rewriter.getMultiDimIdentityMap(rank)};
+
+    // The body just forwards the input element; the permutation lives
+    // entirely in the indexing maps.
+    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
+        op, resultTy, op.input1(), ValueRange{initTensor}, affineMaps,
+        getNParallelLoopsAttrs(rank),
+        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+          nestedBuilder.create<linalg::YieldOp>(nestedLoc, args.front());
+        });
     return success();
   }
 };
@@ -478,5 +544,6 @@ PointwiseConverter, PointwiseConverter, PointwiseConverter, 
PointwiseConverter, IdentityNConverter,
-      IdentityNConverter<tosa::IdentityNOp>, ReshapeOpConverter>(context);
+      IdentityNConverter<tosa::IdentityNOp>,
+      ReshapeOpConverter, TransposeConverter>(context);
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -317,3 +317,21 @@
   // CHECK: return %arg0, %arg1
   return %2#0, %2#1 : tensor<1xf32>, tensor<1xi32>
 }
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK-LABEL: @test_transpose
+// CHECK-SAME: ([[ARG0:%.+]]: tensor<1x2x3xi32>)
+func @test_transpose(%arg0: tensor<1x2x3xi32>) -> () {
+  %0 = constant dense<[1, 2, 0]> : tensor<3xi32>
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 3, 1]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins([[ARG0]] : tensor<1x2x3xi32>) outs([[INIT]] : tensor<2x3x1xi32>)
+  // CHECK: ^bb0([[ARG1:%.+]]: i32, [[ARG2:%.+]]: i32)
+  // CHECK: linalg.yield [[ARG1]]
+  // CHECK: }
+  %1 = "tosa.transpose"(%arg0, %0) : (tensor<1x2x3xi32>, tensor<3xi32>) -> (tensor<2x3x1xi32>)
+  return
+}