diff --git a/mlir/lib/Dialect/Vector/CMakeLists.txt b/mlir/lib/Dialect/Vector/CMakeLists.txt
--- a/mlir/lib/Dialect/Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/CMakeLists.txt
@@ -31,4 +31,5 @@
   MLIRDataLayoutInterfaces
   MLIRSideEffectInterfaces
   MLIRVectorInterfaces
+  MLIRLLVMIR
   )
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -14,6 +14,7 @@
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
@@ -2245,10 +2246,30 @@
     int64_t rank = dstType.getRank();
     if (rank > 1)
       return failure();
-    rewriter.replaceOp(
-        op, buildVectorComparison(rewriter, op, indexOptimizations,
-                                  rank == 0 ? 0 : dstType.getDimSize(0),
-                                  op.getOperand(0)));
+    // The non-scalable path below sizes the mask from the static dimension;
+    // a scalable mask has no static length, so lower it to LLVM's
+    // get.active.lane.mask(0, bound) intrinsic instead.
+    if (dstType.cast<VectorType>().isScalable()) {
+      if (rank == 0)
+        return failure();
+      Type idxType =
+          indexOptimizations ? rewriter.getI32Type() : rewriter.getI64Type();
+      auto loc = op->getLoc();
+      // The attribute type must match idxType (i32 or i64); an i32 attribute
+      // on an i64 arith.constant fails verification.
+      Value zero = rewriter.create<arith::ConstantOp>(
+          loc, idxType, rewriter.getIntegerAttr(idxType, 0));
+      Value bound =
+          createCastToIndexLike(rewriter, loc, idxType, op.getOperand(0));
+      Value getActiveLaneMask = rewriter.create<LLVM::GetActiveLaneMaskOp>(
+          loc, dstType, zero, bound);
+      rewriter.replaceOp(op, getActiveLaneMask);
+    } else {
+      rewriter.replaceOp(
+          op, buildVectorComparison(rewriter, op, indexOptimizations,
+                                    rank == 0 ? 0 : dstType.getDimSize(0),
+                                    op.getOperand(0)));
+    }
     return success();
   }
 