diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -325,6 +325,51 @@
   return success();
 }
 
+// TODO: Improve names. Make it specific to tensor::ExtractOp for now?
+enum VectorMemoryAccessKind {
+  ScalarBroadcast,
+  Contiguous,
+  Gather
+};
+
+/// Returns true if \p indexVal is a stride-1 index along loop dimension
+/// \p dim, i.e. it is produced directly by a linalg.index op for that
+/// dimension.
+static bool isStrideOneAlongDimension(Value indexVal, unsigned dim) {
+  Operation *defOp = indexVal.getDefiningOp();
+  if (!defOp)
+    return false;
+  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
+    return indexOp.getDim() == dim;
+  // TODO: Explore the use-def chain instead of only the defining op.
+  return false;
+}
+
+/// Classifies the memory access pattern of \p extractOp when vectorized
+/// with \p vectorShape; conservatively falls back to a gather.
+static VectorMemoryAccessKind
+getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
+                                    ArrayRef<int64_t> vectorShape) {
+  auto indices = extractOp.getIndices();
+
+  // TODO: Support n-D vectors.
+  if (llvm::count_if(vectorShape,
+                     [](int64_t dimSize) { return dimSize > 1; }) != 1)
+    return VectorMemoryAccessKind::Gather;
+
+  bool isContiguous = true;
+  for (auto [i, indexVal] : llvm::enumerate(indices)) {
+    if (vectorShape[i] == 1)
+      continue;
+    // TODO: Broadcast case.
+    isContiguous &= isStrideOneAlongDimension(indexVal, i);
+  }
+
+  if (isContiguous)
+    return VectorMemoryAccessKind::Contiguous;
+
+  return VectorMemoryAccessKind::Gather;
+}
+
 /// Helper function to vectorize the tensor.extract operations. Returns
 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
 /// should map the produced operations. This function is meant to be used as a
@@ -340,6 +385,9 @@
 
   // Compute the static loop sizes of the extract op.
   auto targetShape = linalgOp.computeStaticLoopSizes();
+  VectorMemoryAccessKind memAccessKind =
+      getTensorExtractMemoryAccessPattern(extractOp, targetShape);
+
   auto resultType =
       VectorType::get(targetShape, extractOp.getResult().getType());
   auto maskConstantOp = b.create<arith::ConstantOp>(
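
Note (illustrative, not part of the patch): the classifier above targets inputs like the
following hypothetical sketch, where the extract index is produced directly by
linalg.index 0 and is therefore stride-1 along the single non-unit vectorized dimension,
so getTensorExtractMemoryAccessPattern would return Contiguous. An index computed in any
other way currently falls back to Gather. The function name is made up for the example,
and the exact linalg.generic syntax may vary between MLIR versions:

  #map = affine_map<(d0) -> (d0)>
  func.func @contiguous_extract(%src : tensor<8xf32>,
                                %init : tensor<8xf32>) -> tensor<8xf32> {
    %0 = linalg.generic
        {indexing_maps = [#map], iterator_types = ["parallel"]}
        outs(%init : tensor<8xf32>) {
      ^bb0(%out: f32):
        // Stride-1 index: produced directly by linalg.index 0.
        %idx = linalg.index 0 : index
        %val = tensor.extract %src[%idx] : tensor<8xf32>
        linalg.yield %val : f32
    } -> tensor<8xf32>
    return %0 : tensor<8xf32>
  }

Here computeStaticLoopSizes() yields [8], so exactly one vector dimension is non-unit and
the early bail-out to Gather is not taken.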