diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2217,6 +2217,7 @@ def Vector_TransposeOp : Vector_Op<"transpose", [NoSideEffect, + DeclareOpInterfaceMethods, PredOpTrait<"operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>]>, Arguments<(ins AnyVector:$vector, I64ArrayAttr:$transp)>, diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -4320,6 +4320,10 @@ return success(); } +Optional> TransposeOp::getShapeForUnroll() { + return llvm::to_vector<4>(getResultType().getShape()); +} + namespace { // Rewrites two back-to-back TransposeOp operations into a single TransposeOp. diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp @@ -681,14 +681,62 @@ const vector::UnrollVectorOptions options; }; +struct UnrollTranposePattern : public OpRewritePattern { + UnrollTranposePattern(MLIRContext *context, + const vector::UnrollVectorOptions &options) + : OpRewritePattern(context, /*benefit=*/1), + options(options) {} + LogicalResult matchAndRewrite(vector::TransposeOp tranposeOp, + PatternRewriter &rewriter) const override { + if (tranposeOp.getResultType().getRank() == 0) + return failure(); + auto targetShape = getTargetShape(options, tranposeOp); + if (!targetShape) + return failure(); + auto originalVectorType = tranposeOp.getResultType(); + SmallVector strides(targetShape->size(), 1); + Location loc = tranposeOp.getLoc(); + ArrayRef originalSize = originalVectorType.getShape(); + SmallVector ratio = *shapeRatio(originalSize, *targetShape); + int64_t sliceCount = computeMaxLinearIndex(ratio); + // Prepare the result vector; + Value result = rewriter.create( + loc, originalVectorType, rewriter.getZeroAttr(originalVectorType)); + SmallVector permutation; + tranposeOp.getTransp(permutation); + for (int64_t i = 0; i < sliceCount; i++) { + SmallVector elementOffsets = + getVectorOffset(originalSize, *targetShape, i); + SmallVector permutedOffsets(elementOffsets.size()); + SmallVector permutedShape(elementOffsets.size()); + // Compute the source offsets and shape. + for (auto &indices : llvm::enumerate(permutation)) { + permutedOffsets[indices.value()] = elementOffsets[indices.index()]; + permutedShape[indices.value()] = (*targetShape)[indices.index()]; + } + Value slicedOperand = rewriter.create( + loc, tranposeOp.getVector(), permutedOffsets, permutedShape, strides); + Value tranposedSlice = + rewriter.create(loc, slicedOperand, permutation); + result = rewriter.create( + loc, tranposedSlice, result, elementOffsets, strides); + } + rewriter.replaceOp(tranposeOp, result); + return success(); + } + +private: + vector::UnrollVectorOptions options; +}; + } // namespace void mlir::vector::populateVectorUnrollPatterns( RewritePatternSet &patterns, const UnrollVectorOptions &options) { patterns.add( - patterns.getContext(), options); + UnrollReductionPattern, UnrollMultiReductionPattern, + UnrollTranposePattern>(patterns.getContext(), options); } void mlir::vector::populatePropagateVectorDistributionPatterns( diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir --- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir +++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir @@ -107,6 +107,11 @@ // CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[A3]], %[[V1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32> // CHECK: return %[[V2]] : vector<4xf32> + +func @vector_reduction(%v : vector<8xf32>) -> f32 { + %0 = vector.reduction , %v : vector<8xf32> into f32 + return %0 : f32 +} // CHECK-LABEL: func @vector_reduction( // CHECK-SAME: %[[v:.*]]: vector<8xf32> // CHECK: %[[s0:.*]] = vector.extract_strided_slice %[[v]] {offsets = [0], sizes = [2] @@ -121,8 +126,35 @@ // CHECK: %[[r3:.*]] = vector.reduction , %[[s3]] // CHECK: %[[add3:.*]] = arith.addf %[[add2]], %[[r3]] // CHECK: return %[[add3]] -func @vector_reduction(%v : vector<8xf32>) -> f32 { - %0 = vector.reduction , %v : vector<8xf32> into f32 - return %0 : f32 -} +func @vector_tranpose(%v : vector<2x4x3x8xf32>) -> vector<2x3x8x4xf32> { + %t = vector.transpose %v, [0, 2, 3, 1] : vector<2x4x3x8xf32> to vector<2x3x8x4xf32> + return %t : vector<2x3x8x4xf32> +} +// CHECK-LABEL: func @vector_tranpose +// CHECK: %[[VI:.*]] = arith.constant dense<0.000000e+00> : vector<2x3x8x4xf32> +// CHECK: %[[E0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0, 0], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32> +// CHECK: %[[T0:.*]] = vector.transpose %[[E0]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32> +// CHECK: %[[V0:.*]] = vector.insert_strided_slice %[[T0]], %[[VI]] {offsets = [0, 0, 0, 0], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32> +// CHECK: %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2, 0, 0], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32> +// CHECK: %[[T1:.*]] = vector.transpose %[[E1]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32> +// CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[T1]], %[[V0]] {offsets = [0, 0, 0, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32> +// CHECK: %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0, 4], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32> +// CHECK: %[[T2:.*]] = vector.transpose %[[E2]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32> +// CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[T2]], %[[V1]] {offsets = [0, 0, 4, 0], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32> +// CHECK: %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2, 0, 4], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32> +// CHECK: %[[T3:.*]] = vector.transpose %[[E3]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32> +// CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[T3]], %[[V2]] {offsets = [0, 0, 4, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32> +// CHECK: %[[E4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0, 0], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32> +// CHECK: %[[T4:.*]] = vector.transpose %[[E4]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32> +// CHECK: %[[V4:.*]] = vector.insert_strided_slice %[[T4]], %[[V3]] {offsets = [1, 0, 0, 0], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32> +// CHECK: %[[E5:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 2, 0, 0], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32> +// CHECK: %[[T5:.*]] = vector.transpose %[[E5]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32> +// CHECK: %[[V5:.*]] = vector.insert_strided_slice %[[T5]], %[[V4]] {offsets = [1, 0, 0, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32> +// CHECK: %[[E6:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0, 4], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32> +// CHECK: %[[T6:.*]] = vector.transpose %[[E6]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32> +// CHECK: %[[V6:.*]] = vector.insert_strided_slice %[[T6]], %[[V5]] {offsets = [1, 0, 4, 0], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32> +// CHECK: %[[E7:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 2, 0, 4], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32> +// CHECK: %[[T7:.*]] = vector.transpose %[[E7]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32> +// CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[T7]], %[[V6]] {offsets = [1, 0, 4, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32> +// CHECK: return %[[V7]] : vector<2x3x8x4xf32> diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -282,6 +282,12 @@ .setFilterConstraint([](Operation *op) { return success(isa(op)); })); + populateVectorUnrollPatterns( + patterns, UnrollVectorOptions() + .setNativeShape(ArrayRef{1, 3, 4, 2}) + .setFilterConstraint([](Operation *op) { + return success(isa(op)); + })); if (unrollBasedOnType) { UnrollVectorOptions::NativeShapeFnType nativeShapeFn =