diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -591,23 +591,23 @@ const vector::UnrollVectorOptions options; }; -struct UnrollTranposePattern : public OpRewritePattern { - UnrollTranposePattern(MLIRContext *context, - const vector::UnrollVectorOptions &options, - PatternBenefit benefit = 1) +struct UnrollTransposePattern : public OpRewritePattern { + UnrollTransposePattern(MLIRContext *context, + const vector::UnrollVectorOptions &options, + PatternBenefit benefit = 1) : OpRewritePattern(context, benefit), options(options) {} - LogicalResult matchAndRewrite(vector::TransposeOp tranposeOp, + LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, PatternRewriter &rewriter) const override { - if (tranposeOp.getResultType().getRank() == 0) + if (transposeOp.getResultType().getRank() == 0) return failure(); - auto targetShape = getTargetShape(options, tranposeOp); + auto targetShape = getTargetShape(options, transposeOp); if (!targetShape) return failure(); - auto originalVectorType = tranposeOp.getResultType(); + auto originalVectorType = transposeOp.getResultType(); SmallVector strides(targetShape->size(), 1); - Location loc = tranposeOp.getLoc(); + Location loc = transposeOp.getLoc(); ArrayRef originalSize = originalVectorType.getShape(); SmallVector ratio = *computeShapeRatio(originalSize, *targetShape); int64_t sliceCount = computeMaxLinearIndex(ratio); @@ -615,7 +615,7 @@ Value result = rewriter.create( loc, originalVectorType, rewriter.getZeroAttr(originalVectorType)); SmallVector permutation; - tranposeOp.getTransp(permutation); + transposeOp.getTransp(permutation); // Stride of the ratios, this gives us the offsets of sliceCount in a basis // of multiples of the targetShape. @@ -631,13 +631,14 @@ permutedShape[indices.value()] = (*targetShape)[indices.index()]; } Value slicedOperand = rewriter.create( - loc, tranposeOp.getVector(), permutedOffsets, permutedShape, strides); - Value tranposedSlice = + loc, transposeOp.getVector(), permutedOffsets, permutedShape, + strides); + Value transposedSlice = rewriter.create(loc, slicedOperand, permutation); result = rewriter.create( - loc, tranposedSlice, result, elementOffsets, strides); + loc, transposedSlice, result, elementOffsets, strides); } - rewriter.replaceOp(tranposeOp, result); + rewriter.replaceOp(transposeOp, result); return success(); } @@ -653,5 +654,5 @@ patterns.add(patterns.getContext(), options, benefit); + UnrollTransposePattern>(patterns.getContext(), options, benefit); }