diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1760,10 +1760,10 @@
   Vector_Op<"gather">,
     Arguments<(ins Arg<AnyShaped, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
-               VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec,
-               VectorOfRankAndType<[1], [I1]>:$mask,
-               VectorOfRank<[1]>:$pass_thru)>,
-    Results<(outs VectorOfRank<[1]>:$result)> {
+               VectorOf<[AnyInteger, Index]>:$index_vec,
+               VectorOf<[I1]>:$mask,
+               AnyVector:$pass_thru)>,
+    Results<(outs AnyVector:$result)> {
 
   let summary = [{
     gathers elements from memory or ranked tensor into a vector as defined by an
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -82,14 +82,6 @@
     std::function<Value(Type, ValueRange)> createOperand,
     ConversionPatternRewriter &rewriter) {
   auto resultNDVectorType = op->getResult(0).getType().cast<VectorType>();
-
-  SmallVector<Type> operand1DVectorTypes;
-  for (Value operand : op->getOperands()) {
-    auto operandNDVectorType = operand.getType().cast<VectorType>();
-    auto operandTypeInfo =
-        extractNDVectorTypeInfo(operandNDVectorType, typeConverter);
-    operand1DVectorTypes.push_back(operandTypeInfo.llvm1DVectorTy);
-  }
   auto resultTypeInfo =
       extractNDVectorTypeInfo(resultNDVectorType, typeConverter);
   auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy;
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -91,24 +91,26 @@
   return success();
 }
 
-// Add an index vector component to a base pointer. This almost always succeeds
-// unless the last stride is non-unit or the memory space is not zero.
-static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
-                                    Location loc, Value memref, Value base,
-                                    Value index, MemRefType memRefType,
-                                    VectorType vType, Value &ptrs) {
+// Succeeds only for strided memrefs with a unit last stride in memory space 0.
+static LogicalResult isMemRefTypeSupported(MemRefType memRefType) {
   int64_t offset;
   SmallVector<int64_t, 4> strides;
   auto successStrides = getStridesAndOffset(memRefType, strides, offset);
   if (failed(successStrides) || strides.back() != 1 ||
       memRefType.getMemorySpaceAsInt() != 0)
     return failure();
-  auto pType = MemRefDescriptor(memref).getElementPtrType();
-  auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
-  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
   return success();
 }
 
+// Add an index vector component to a base pointer.
+static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
+                            Value memref, Value base, Value index,
+                            uint64_t vLen) {
+  auto pType = MemRefDescriptor(memref).getElementPtrType();
+  auto ptrsType = LLVM::getFixedVectorType(pType, vLen);
+  return rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
+}
+
 // Casts a strided element pointer to a vector pointer. The vector pointer
 // will be in the same address space as the incoming memref type.
 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
@@ -257,29 +259,54 @@
   LogicalResult
   matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto loc = gather->getLoc();
     MemRefType memRefType = gather.getBaseType().dyn_cast<MemRefType>();
     assert(memRefType && "The base should be bufferized");
 
+    if (failed(isMemRefTypeSupported(memRefType)))
+      return failure();
+
+    auto loc = gather->getLoc();
+
     // Resolve alignment.
     unsigned align;
     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
       return failure();
 
-    // Resolve address.
-    Value ptrs;
-    VectorType vType = gather.getVectorType();
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
-    if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr,
-                              adaptor.getIndexVec(), memRefType, vType, ptrs)))
-      return failure();
+    Value base = adaptor.getBase();
+
+    // A converted 1-D vector is not an LLVM array: emit one masked gather.
+    auto llvmNDVectorTy = adaptor.getIndexVec().getType();
+    if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>()) {
+      auto vType = gather.getVectorType();
+      // Resolve address.
+      Value ptrs =
+          getIndexedPtrs(rewriter, loc, base, ptr, adaptor.getIndexVec(),
+                         /*vLen=*/vType.getDimSize(0));
+      // Replace with the gather intrinsic.
+      rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
+          gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
+          adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
+      return success();
+    }
 
-    // Replace with the gather intrinsic.
-    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
-        gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
-        adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
-    return success();
+    // n-D case: emit one masked gather per innermost 1-D vector.
+    auto callback = [align, base, ptr, loc, &rewriter](Type llvm1DVectorTy,
+                                                       ValueRange operands) {
+      // Resolve address.
+      Value ptrs = getIndexedPtrs(
+          rewriter, loc, base, ptr, /*index=*/operands[0],
+          LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue());
+      // Create the gather intrinsic.
+      return rewriter.create<LLVM::masked_gather>(
+          loc, llvm1DVectorTy, ptrs, /*mask=*/operands[1],
+          /*passThru=*/operands[2], rewriter.getI32IntegerAttr(align));
+    };
+    return LLVM::detail::handleMultidimensionalVectors(
+        gather,
+        {adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()},
+        *getTypeConverter(), callback, rewriter);
   }
 };
 
@@ -295,19 +322,21 @@
     auto loc = scatter->getLoc();
     MemRefType memRefType = scatter.getMemRefType();
 
+    if (failed(isMemRefTypeSupported(memRefType)))
+      return failure();
+
     // Resolve alignment.
     unsigned align;
     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
       return failure();
 
     // Resolve address.
-    Value ptrs;
     VectorType vType = scatter.getVectorType();
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
-    if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr,
-                              adaptor.getIndexVec(), memRefType, vType, ptrs)))
-      return failure();
+    Value ptrs =
+        getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr,
+                       adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));
 
     // Replace with the scatter intrinsic.
     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4074,9 +4074,9 @@
     return emitOpError("base and result element type should match");
   if (llvm::size(getIndices()) != baseType.getRank())
     return emitOpError("requires ") << baseType.getRank() << " indices";
-  if (resVType.getDimSize(0) != indVType.getDimSize(0))
+  if (resVType.getShape() != indVType.getShape())
     return emitOpError("expected result dim to match indices dim");
-  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
+  if (resVType.getShape() != maskVType.getShape())
     return emitOpError("expected result dim to match mask dim");
   if (resVType != getPassThruVectorType())
     return emitOpError("expected pass_thru of same type as result type");