diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1763,21 +1763,21 @@
 def Vector_GatherOp :
   Vector_Op<"gather">,
-    Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
+    Arguments<(ins Arg<AnyShaped, "", [MemRead]>:$base,
                    Variadic<Index>:$indices,
                    VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec,
                    VectorOfRankAndType<[1], [I1]>:$mask,
                    VectorOfRank<[1]>:$pass_thru)>,
     Results<(outs VectorOfRank<[1]>:$result)> {

-  let summary = "gathers elements from memory into a vector as defined by an index vector and mask";
+  let summary = "gathers elements from memory or ranked tensor into a vector as defined by an index vector and mask";

   let description = [{
-    The gather operation gathers elements from memory into a 1-D vector as
-    defined by a base with indices and an additional 1-D index vector, but
-    only if the corresponding bit is set in a 1-D mask vector. Otherwise, the
-    element is taken from a 1-D pass-through vector. Informally the semantics
-    are:
+    The gather operation gathers elements from memory or ranked tensor into a
+    1-D vector as defined by a base with indices and an additional 1-D index
+    vector, but only if the corresponding bit is set in a 1-D mask vector.
+    Otherwise, the element is taken from a 1-D pass-through vector. Informally
+    the semantics are:

     ```
     result[0] := mask[0] ? base[index[0]] : pass_thru[0]
     result[1] := mask[1] ?
base[index[1]] : pass_thru[1]
@@ -1802,8 +1802,8 @@
     ```
   }];
   let extraClassDeclaration = [{
-    MemRefType getMemRefType() {
-      return getBase().getType().cast<MemRefType>();
+    ShapedType getBaseType() {
+      return getBase().getType().cast<ShapedType>();
     }
     VectorType getIndexVectorType() {
       return getIndexVec().getType().cast<VectorType>();
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -260,7 +260,8 @@
   matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = gather->getLoc();
-    MemRefType memRefType = gather.getMemRefType();
+    // The base should have been bufferized.
+    MemRefType memRefType = gather.getBaseType().cast<MemRefType>();

     // Resolve alignment.
     unsigned align;
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3985,12 +3985,15 @@
   VectorType indVType = getIndexVectorType();
   VectorType maskVType = getMaskVectorType();
   VectorType resVType = getVectorType();
-  MemRefType memType = getMemRefType();
+  ShapedType baseType = getBaseType();

-  if (resVType.getElementType() != memType.getElementType())
+  if (!baseType.isa<MemRefType, RankedTensorType>())
+    return emitOpError("requires base to be a memref or ranked tensor type");
+
+  if (resVType.getElementType() != baseType.getElementType())
     return emitOpError("base and result element type should match");
-  if (llvm::size(getIndices()) != memType.getRank())
-    return emitOpError("requires ") << memType.getRank() << " indices";
+  if (llvm::size(getIndices()) != baseType.getRank())
+    return emitOpError("requires ") << baseType.getRank() << " indices";
   if (resVType.getDimSize(0) != indVType.getDimSize(0))
     return emitOpError("expected result dim to match indices dim");
   if
(resVType.getDimSize(0) != maskVType.getDimSize(0))
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -113,6 +113,46 @@
   }
 };

+/// Bufferization of vector.gather. Replaced with a new vector.gather that
+/// operates on a memref.
+struct GatherOpInterface
+    : public BufferizableOpInterface::ExternalModel<GatherOpInterface,
+                                                    vector::GatherOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    assert(opOperand.get().getType().isa<RankedTensorType>() &&
+           "only tensor types expected");
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    assert(opOperand.get().getType().isa<RankedTensorType>() &&
+           "only tensor types expected");
+    return false;
+  }
+
+  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                                            const AnalysisState &state) const {
+    return {};
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    auto gatherOp = cast<vector::GatherOp>(op);
+    assert(gatherOp.getBaseType().isa<RankedTensorType>() &&
+           "only tensor types expected");
+    FailureOr<Value> buffer = getBuffer(rewriter, gatherOp.getBase(), options);
+    if (failed(buffer))
+      return failure();
+    replaceOpWithNewBufferizedOp<vector::GatherOp>(
+        rewriter, gatherOp, gatherOp.getVectorType(), *buffer,
+        gatherOp.getIndices(), gatherOp.getIndexVec(), gatherOp.getMask(),
+        gatherOp.getPassThru());
+    return success();
+  }
+};
+
 } // namespace
 } // namespace vector
 } // namespace mlir
@@ -122,5 +162,6 @@
   registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
     TransferReadOp::attachInterface<TransferReadOpInterface>(*ctx);
     TransferWriteOp::attachInterface<TransferWriteOpInterface>(*ctx);
+    GatherOp::attachInterface<GatherOpInterface>(*ctx);
   });
 }
diff --git a/mlir/test/Dialect/Vector/bufferize.mlir
b/mlir/test/Dialect/Vector/bufferize.mlir
--- a/mlir/test/Dialect/Vector/bufferize.mlir
+++ b/mlir/test/Dialect/Vector/bufferize.mlir
@@ -29,3 +29,17 @@
     : vector<5x6xf32>, tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @gather(
+//  CHECK-SAME:     %[[base:.*]]: tensor<?x?xf32>, %[[v:.*]]: vector<16xi32>,
+//  CHECK-SAME:     %[[mask:.*]]: vector<16xi1>, %[[pass_thru:.*]]: vector<16xf32>)
+//       CHECK:   %[[m:.*]] = bufferization.to_memref %[[base]] : memref<?x?xf32>
+//       CHECK:   %[[c0:.*]] = arith.constant 0 : index
+//       CHECK:   %[[out:.*]] = vector.gather %[[m]][%[[c0]], %[[c0]]] [%[[v]]], %[[mask]], %[[pass_thru]] : memref<?x?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+func.func @gather(%base: tensor<?x?xf32>, %v: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = vector.gather %base[%c0, %c0][%v], %mask, %pass_thru : tensor<?x?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  return %0 : vector<16xf32>
+}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -669,6 +669,14 @@
   return
 }

+// CHECK-LABEL: @gather_on_tensor
+func.func @gather_on_tensor(%base: tensor<?xf32>, %v: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[X:.*]] = vector.gather %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %{{.*}} : tensor<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  return %0 : vector<16xf32>
+}
+
 // CHECK-LABEL: @expand_and_compress
 func.func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   %c0 = arith.constant 0 : index