diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -65,7 +65,7 @@
 /// Collect a set of patterns to convert from the Vector dialect to LLVM.
 void populateVectorToLLVMConversionPatterns(
     LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    bool reassociateFPReductions = false);
+    bool reassociateFPReductions = false, bool indexOptimizations = false);
 
 /// Create a pass to convert vector operations to the LLVMIR dialect.
 std::unique_ptr> createConvertVectorToLLVMPass(
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -901,6 +901,62 @@
   }
 };
 
+static Value createCastToIndexLike(PatternRewriter &rewriter, Location loc,
+                                   Type targetType, Value value) {
+  if (targetType == value.getType())
+    return value;
+  // Types differ. If exactly one of them is `index`, one cast op converts.
+  bool targetIsIndex = targetType.isIndex();
+  bool valueIsIndex = value.getType().isIndex();
+  if (targetIsIndex ^ valueIsIndex)
+    return rewriter.create(loc, targetType, value);
+  // Neither type is `index` here, so both must be integers of equal signedness.
+  auto targetIntegerType = targetType.dyn_cast();
+  auto valueIntegerType = value.getType().dyn_cast();
+  assert(targetIntegerType && valueIntegerType &&
+         "unexpected cast between types other than integers and index");
+  assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
+  // Choose the op by width: one op when the target is wider, another otherwise.
+  if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
+    return rewriter.create(loc, targetIntegerType, value);
+  return rewriter.create(loc, targetIntegerType, value);
+}
+
+/// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
+/// Non-scalable versions of this operation are handled in Vector Transforms.
+class VectorCreateMaskOpRewritePattern
+    : public OpRewritePattern {
+public:
+  explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
+                                            bool enableIndexOpt)
+      : mlir::OpRewritePattern(context),
+        indexOptimizations(enableIndexOpt) {}
+
+  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
+                                PatternRewriter &rewriter) const override {
+    auto dstType = op.getType();
+    // This pattern only lowers rank-1 scalable masks.
+    if (dstType.getRank() != 1 || !dstType.cast().isScalable())
+      return failure();
+    // Use i32 when index optimizations are enabled, i64 otherwise.
+    auto idxType =
+        indexOptimizations ? rewriter.getI32Type() : rewriter.getI64Type();
+    auto loc = op->getLoc();
+    // Build the zero constant with an attribute of type idxType (i32 or i64).
+    auto zero = rewriter.create(
+        loc, idxType, rewriter.getIntegerAttr(idxType, 0));
+    auto bound =
+        createCastToIndexLike(rewriter, loc, idxType, op.getOperand(0));
+    Value getActiveLaneMask =
+        rewriter.create(loc, dstType, zero, bound);
+    rewriter.replaceOp(op, getActiveLaneMask);
+    return success();
+  }
+
+private:
+  const bool indexOptimizations;
+};
+
 class VectorPrintOpConversion : public ConvertOpToLLVMPattern {
 public:
   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
@@ -1067,11 +1123,12 @@
 /// Populate the given list with patterns that convert from Vector to LLVM.
 void mlir::populateVectorToLLVMConversionPatterns(
     LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    bool reassociateFPReductions) {
+    bool reassociateFPReductions, bool indexOptimizations) {
   MLIRContext *ctx = converter.getDialect()->getContext();
   patterns.add(ctx);
   populateVectorInsertExtractStridedSliceTransforms(patterns);
   patterns.add(converter, reassociateFPReductions);
+  patterns.add(ctx, indexOptimizations);
   patterns
       .add().isScalable())
+      return failure();
     // Gather constant mask dimension sizes.
     SmallVector maskDimSizes;
     for (auto it : llvm::zip(createMaskOp.operands(),
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -14,6 +14,7 @@
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
@@ -2242,6 +2243,8 @@
   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                 PatternRewriter &rewriter) const override {
     auto dstType = op.getType();
+    if (dstType.cast().isScalable()) // Scalable masks: see VectorToLLVM.
+      return failure();
     int64_t rank = dstType.getRank();
     if (rank > 1)
       return failure();