diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -1838,7 +1838,10 @@ } def Vector_GatherOp : - Vector_Op<"gather", [DeclareOpInterfaceMethods]>, + Vector_Op<"gather", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods + ]>, Arguments<(ins Arg:$base, Variadic:$indices, VectorOf<[AnyInteger, Index]>:$index_vec, diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -4630,6 +4630,10 @@ IntegerType::get(vecType.getContext(), /*width=*/1)); } +std::optional> GatherOp::getShapeForUnroll() { + return llvm::to_vector<4>(getVectorType().getShape()); +} + namespace { class GatherFolder final : public OpRewritePattern { public: diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -641,6 +641,61 @@ vector::UnrollVectorOptions options; }; +struct UnrollGatherPattern : public OpRewritePattern { + UnrollGatherPattern(MLIRContext *context, + const vector::UnrollVectorOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), options(options) { + } + + LogicalResult matchAndRewrite(vector::GatherOp gatherOp, + PatternRewriter &rewriter) const override { + VectorType sourceVectorType = gatherOp.getVectorType(); + if (sourceVectorType.getRank() == 0) + return failure(); + auto targetShape = getTargetShape(options, gatherOp); + if (!targetShape) + return failure(); + SmallVector strides(targetShape->size(), 1); + Location loc = gatherOp.getLoc(); + ArrayRef originalSize = gatherOp.getVectorType().getShape(); + + // Prepare the result vector; + Value result = rewriter.create( + loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType)); + auto targetType = + VectorType::get(*targetShape, sourceVectorType.getElementType()); + + SmallVector loopOrder = + getUnrollOrder(originalSize.size(), gatherOp, options); + DecomposeShapeIterator indexToOffsets(originalSize, *targetShape, + loopOrder); + for (int64_t i = 0, e = indexToOffsets.maxIndex(); i < e; ++i) { + // To get the unrolled gather, extract the same slice based on the + // decomposed shape from each of the index, mask, and pass-through + // vectors. + SmallVector elementOffsets = indexToOffsets.getVectorOffset(i); + Value indexSubVec = rewriter.create( + loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides); + Value maskSubVec = rewriter.create( + loc, gatherOp.getMask(), elementOffsets, *targetShape, strides); + Value passThruSubVec = rewriter.create( + loc, gatherOp.getPassThru(), elementOffsets, *targetShape, strides); + auto slicedGather = rewriter.create( + loc, targetType, gatherOp.getBase(), gatherOp.getIndices(), + indexSubVec, maskSubVec, passThruSubVec); + + result = rewriter.create( + loc, slicedGather, result, elementOffsets, strides); + } + rewriter.replaceOp(gatherOp, result); + return success(); + } + +private: + vector::UnrollVectorOptions options; +}; + } // namespace void mlir::vector::populateVectorUnrollPatterns( @@ -649,5 +704,6 @@ patterns.add(patterns.getContext(), options, benefit); + UnrollTransposePattern, UnrollGatherPattern>( + patterns.getContext(), options, benefit); } diff --git a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir @@ -275,3 +275,90 @@ %0 = vector.transfer_read %arg0[%c0, %c0, %c0], %cf0 {permutation_map = #map0} : memref, vector<6x4xf32> return %0 : vector<6x4xf32> } + +// ----- + +// CHECK-LABEL: func @vector_gather_unroll +// CHECK-SAME: %[[ARG0:.*]]: memref +// CHECK-SAME: %[[ARG1:.*]]: vector<6x4xindex> +// CHECK-SAME: %[[ARG2:.*]]: vector<6x4xi1> +// CHECK-SAME: %[[ARG3:.*]]: vector<6x4xf32> +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[IDX0:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xindex> to vector<2x2xindex> +// CHECK-NEXT: %[[MASK0:.*]] = vector.extract_strided_slice %[[ARG2]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xi1> to vector<2x2xi1> +// CHECK-NEXT: %[[PASS0:.*]] = vector.extract_strided_slice %[[ARG3]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> +// CHECK-NEXT: %[[VGT0:.*]] = vector.gather {{.*}}[%[[C0]], %[[C0]], %[[C0]]] [%[[IDX0]]], %[[MASK0]], %[[PASS0]] : memref, vector<2x2xindex>, vector<2x2xi1>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[VEC0:.*]] = vector.insert_strided_slice %[[VGT0]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> +// CHECK-NEXT: %[[IDX1:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xindex> to vector<2x2xindex> +// CHECK-NEXT: %[[MASK1:.*]] = vector.extract_strided_slice %[[ARG2]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xi1> to vector<2x2xi1> +// CHECK-NEXT: %[[PASS1:.*]] = vector.extract_strided_slice %[[ARG3]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> +// CHECK-NEXT: %[[VGT1:.*]] = vector.gather {{.*}}[%[[C0]], %[[C0]], %[[C0]]] [%[[IDX1]]], %[[MASK1]], %[[PASS1]] : memref, vector<2x2xindex>, vector<2x2xi1>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[VEC1:.*]] = vector.insert_strided_slice %[[VGT1]], %[[VEC0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> +// CHECK-NEXT: %[[IDX2:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xindex> to vector<2x2xindex> +// CHECK-NEXT: %[[MASK2:.*]] = vector.extract_strided_slice %[[ARG2]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xi1> to vector<2x2xi1> +// CHECK-NEXT: %[[PASS2:.*]] = vector.extract_strided_slice %[[ARG3]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> +// CHECK-NEXT: %[[VGT2:.*]] = vector.gather {{.*}}[%[[C0]], %[[C0]], %[[C0]]] [%[[IDX2]]], %[[MASK2]], %[[PASS2]] : memref, vector<2x2xindex>, vector<2x2xi1>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[VEC2:.*]] = vector.insert_strided_slice %[[VGT2]], %[[VEC1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> +// CHECK-NEXT: %[[IDX3:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xindex> to vector<2x2xindex> +// CHECK-NEXT: %[[MASK3:.*]] = vector.extract_strided_slice %[[ARG2]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xi1> to vector<2x2xi1> +// CHECK-NEXT: %[[PASS3:.*]] = vector.extract_strided_slice %[[ARG3]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> +// CHECK-NEXT: %[[VGT3:.*]] = vector.gather {{.*}}[%[[C0]], %[[C0]], %[[C0]]] [%[[IDX3]]], %[[MASK3]], %[[PASS3]] : memref, vector<2x2xindex>, vector<2x2xi1>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[VEC3:.*]] = vector.insert_strided_slice %[[VGT3]], %[[VEC2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> +// CHECK-NEXT: %[[IDX4:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [4, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xindex> to vector<2x2xindex> +// CHECK-NEXT: %[[MASK4:.*]] = vector.extract_strided_slice %[[ARG2]] {offsets = [4, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xi1> to vector<2x2xi1> +// CHECK-NEXT: %[[PASS4:.*]] = vector.extract_strided_slice %[[ARG3]] {offsets = [4, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> +// CHECK-NEXT: %[[VGT4:.*]] = vector.gather {{.*}}[%[[C0]], %[[C0]], %[[C0]]] [%[[IDX4]]], %[[MASK4]], %[[PASS4]] : memref, vector<2x2xindex>, vector<2x2xi1>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[VEC4:.*]] = vector.insert_strided_slice %[[VGT4]], %[[VEC3]] {offsets = [4, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> +// CHECK-NEXT: %[[IDX5:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [4, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xindex> to vector<2x2xindex> +// CHECK-NEXT: %[[MASK5:.*]] = vector.extract_strided_slice %[[ARG2]] {offsets = [4, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xi1> to vector<2x2xi1> +// CHECK-NEXT: %[[PASS5:.*]] = vector.extract_strided_slice %[[ARG3]] {offsets = [4, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> +// CHECK-NEXT: %[[VGT5:.*]] = vector.gather {{.*}}[%[[C0]], %[[C0]], %[[C0]]] [%[[IDX5]]], %[[MASK5]], %[[PASS5]] : memref, vector<2x2xindex>, vector<2x2xi1>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[VEC5:.*]] = vector.insert_strided_slice %[[VGT5]], %[[VEC4]] {offsets = [4, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> +// CHECK-NEXT: return %[[VEC5]] : vector<6x4xf32> + +// ORDER-LABEL: func @vector_gather_unroll +// ORDER-SAME: %[[ARG0:.*]]: memref +// ORDER-SAME: %[[ARG1:.*]]: vector<6x4xindex> +// ORDER-SAME: %[[ARG2:.*]]: vector<6x4xi1> +// ORDER-SAME: %[[ARG3:.*]]: vector<6x4xf32> +// ORDER-DAG: %[[C0:.*]] = arith.constant 0 : index +// ORDER: %[[IDX0:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xindex> to vector<2x2xindex> +// ORDER-NEXT: %[[MASK0:.*]] = vector.extract_strided_slice %[[ARG2]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xi1> to vector<2x2xi1> +// ORDER-NEXT: %[[PASS0:.*]] = vector.extract_strided_slice %[[ARG3]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> +// ORDER-NEXT: %[[VGT0:.*]] = vector.gather {{.*}}[%[[C0]], %[[C0]], %[[C0]]] [%[[IDX0]]], %[[MASK0]], %[[PASS0]] : memref, vector<2x2xindex>, vector<2x2xi1>, vector<2x2xf32> into vector<2x2xf32> +// ORDER-NEXT: %[[VEC0:.*]] = vector.insert_strided_slice %[[VGT0]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> +// ORDER-NEXT: %[[IDX1:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xindex> to vector<2x2xindex> +// ORDER-NEXT: %[[MASK1:.*]] = vector.extract_strided_slice %[[ARG2]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xi1> to vector<2x2xi1> +// ORDER-NEXT: %[[PASS1:.*]] = vector.extract_strided_slice %[[ARG3]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> +// ORDER-NEXT: %[[VGT1:.*]] = vector.gather {{.*}}[%[[C0]], %[[C0]], %[[C0]]] [%[[IDX1]]], %[[MASK1]], %[[PASS1]] : memref, vector<2x2xindex>, vector<2x2xi1>, vector<2x2xf32> into vector<2x2xf32> +// ORDER-NEXT: %[[VEC1:.*]] = vector.insert_strided_slice %[[VGT1]], %[[VEC0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> +// ORDER-NEXT: %[[IDX2:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [4, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xindex> to vector<2x2xindex> +// ORDER-NEXT: %[[MASK2:.*]] = vector.extract_strided_slice %[[ARG2]] {offsets = [4, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xi1> to vector<2x2xi1> +// ORDER-NEXT: %[[PASS2:.*]] = vector.extract_strided_slice %[[ARG3]] {offsets = [4, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> +// ORDER-NEXT: %[[VGT2:.*]] = vector.gather {{.*}}[%[[C0]], %[[C0]], %[[C0]]] [%[[IDX2]]], %[[MASK2]], %[[PASS2]] : memref, vector<2x2xindex>, vector<2x2xi1>, vector<2x2xf32> into vector<2x2xf32> +// ORDER-NEXT: %[[VEC2:.*]] = vector.insert_strided_slice %[[VGT2]], %[[VEC1]] {offsets = [4, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> +// ORDER-NEXT: %[[IDX3:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xindex> to vector<2x2xindex> +// ORDER-NEXT: %[[MASK3:.*]] = vector.extract_strided_slice %[[ARG2]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xi1> to vector<2x2xi1> +// ORDER-NEXT: %[[PASS3:.*]] = vector.extract_strided_slice %[[ARG3]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> +// ORDER-NEXT: %[[VGT3:.*]] = vector.gather {{.*}}[%[[C0]], %[[C0]], %[[C0]]] [%[[IDX3]]], %[[MASK3]], %[[PASS3]] : memref, vector<2x2xindex>, vector<2x2xi1>, vector<2x2xf32> into vector<2x2xf32> +// ORDER-NEXT: %[[VEC3:.*]] = vector.insert_strided_slice %[[VGT3]], %[[VEC2]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> +// ORDER-NEXT: %[[IDX4:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xindex> to vector<2x2xindex> +// ORDER-NEXT: %[[MASK4:.*]] = vector.extract_strided_slice %[[ARG2]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xi1> to vector<2x2xi1> +// ORDER-NEXT: %[[PASS4:.*]] = vector.extract_strided_slice %[[ARG3]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> +// ORDER-NEXT: %[[VGT4:.*]] = vector.gather {{.*}}[%[[C0]], %[[C0]], %[[C0]]] [%[[IDX4]]], %[[MASK4]], %[[PASS4]] : memref, vector<2x2xindex>, vector<2x2xi1>, vector<2x2xf32> into vector<2x2xf32> +// ORDER-NEXT: %[[VEC4:.*]] = vector.insert_strided_slice %[[VGT4]], %[[VEC3]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> +// ORDER-NEXT: %[[IDX5:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [4, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xindex> to vector<2x2xindex> +// ORDER-NEXT: %[[MASK5:.*]] = vector.extract_strided_slice %[[ARG2]] {offsets = [4, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xi1> to vector<2x2xi1> +// ORDER-NEXT: %[[PASS5:.*]] = vector.extract_strided_slice %[[ARG3]] {offsets = [4, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> +// ORDER-NEXT: %[[VGT5:.*]] = vector.gather {{.*}}[%[[C0]], %[[C0]], %[[C0]]] [%[[IDX5]]], %[[MASK5]], %[[PASS5]] : memref, vector<2x2xindex>, vector<2x2xi1>, vector<2x2xf32> into vector<2x2xf32> +// ORDER-NEXT: %[[VEC5:.*]] = vector.insert_strided_slice %[[VGT5]], %[[VEC4]] {offsets = [4, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> +// ORDER-NEXT: return %[[VEC5]] : vector<6x4xf32> + +func.func @vector_gather_unroll(%arg0 : memref, + %indices : vector<6x4xindex>, + %mask : vector<6x4xi1>, + %pass_thru : vector<6x4xf32>) -> vector<6x4xf32> { + %c0 = arith.constant 0 : index + %0 = vector.gather %arg0[%c0, %c0, %c0] [%indices], %mask, %pass_thru : memref, vector<6x4xindex>, vector<6x4xi1>, vector<6x4xf32> into vector<6x4xf32> + return %0 : vector<6x4xf32> +} diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -261,8 +261,8 @@ UnrollVectorOptions opts; opts.setNativeShape(ArrayRef{2, 2}) .setFilterConstraint([](Operation *op) { - return success( - isa(op)); + return success(isa(op)); }); if (reverseUnrollOrder.getValue()) { opts.setUnrollTraversalOrderFn( @@ -272,6 +272,8 @@ numLoops = readOp.getVectorType().getRank(); else if (auto writeOp = dyn_cast(op)) numLoops = writeOp.getVectorType().getRank(); + else if (auto gatherOp = dyn_cast(op)) + numLoops = gatherOp.getVectorType().getRank(); else return std::nullopt; auto order = llvm::reverse(llvm::seq(0, numLoops));