diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -819,10 +819,10 @@
     Option<"reassociateFPReductions", "reassociate-fp-reductions", "bool",
            /*default=*/"false",
            "Allows llvm to reassociate floating-point reductions for speed">,
-    Option<"indexOptimizations", "enable-index-optimizations",
+    Option<"force32BitVectorIndices", "force-32bit-vector-indices",
            "bool", /*default=*/"true",
-           "Allows compiler to assume indices fit in 32-bit if that yields "
-           "faster code">,
+           "Allows compiler to assume vector indices fit in 32-bit if that "
+           "yields faster code">,
     Option<"amx", "enable-amx",
            "bool", /*default=*/"false",
            "Enables the use of AMX dialect while lowering the vector "
diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -28,7 +28,7 @@
     return *this;
   }
   LowerVectorToLLVMOptions &enableIndexOptimizations(bool b = true) {
-    indexOptimizations = b;
+    force32BitVectorIndices = b;
     return *this;
   }
   LowerVectorToLLVMOptions &enableArmNeon(bool b = true) {
@@ -49,7 +49,7 @@
   }

   bool reassociateFPReductions{false};
-  bool indexOptimizations{true};
+  bool force32BitVectorIndices{true};
   bool armNeon{false};
   bool armSVE{false};
   bool amx{false};
@@ -63,10 +63,9 @@
                                                RewritePatternSet &patterns);

 /// Collect a set of patterns to convert from the Vector dialect to LLVM.
-/// If `indexOptimizations` is set, assume indices fit in 32-bit.
 void populateVectorToLLVMConversionPatterns(
     LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    bool reassociateFPReductions = false, bool indexOptimizations = false);
+    bool reassociateFPReductions = false, bool force32BitVectorIndices = false);

 /// Create a pass to convert vector operations to the LLVMIR dialect.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertVectorToLLVMPass(
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -102,7 +102,7 @@

 /// These patterns materialize masks for various vector ops such as transfers.
 void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
-                                               bool indexOptimizations);
+                                               bool force32BitVectorIndices);

 /// Collect a set of patterns to propagate insert_map/extract_map in the ssa
 /// chain.
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -909,7 +909,7 @@
   explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
                                             bool enableIndexOpt)
       : OpRewritePattern<vector::CreateMaskOp>(context),
-        indexOptimizations(enableIndexOpt) {}
+        force32BitVectorIndices(enableIndexOpt) {}

   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                 PatternRewriter &rewriter) const override {
@@ -917,7 +917,7 @@
     if (dstType.getRank() != 1 || !dstType.cast<VectorType>().isScalable())
       return failure();
     IntegerType idxType =
-        indexOptimizations ? rewriter.getI32Type() : rewriter.getI64Type();
+        force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
     auto loc = op->getLoc();
     Value indices = rewriter.create<LLVM::StepVectorOp>(
         loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
@@ -932,7 +932,7 @@
   }

 private:
-  const bool indexOptimizations;
+  const bool force32BitVectorIndices;
 };

 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
@@ -1192,15 +1192,14 @@
 } // namespace

 /// Populate the given list with patterns that convert from Vector to LLVM.
-void mlir::populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
-                                                  RewritePatternSet &patterns,
-                                                  bool reassociateFPReductions,
-                                                  bool indexOptimizations) {
+void mlir::populateVectorToLLVMConversionPatterns(
+    LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    bool reassociateFPReductions, bool force32BitVectorIndices) {
   MLIRContext *ctx = converter.getDialect()->getContext();
   patterns.add<VectorFMAOpNDRewritePattern>(ctx);
   populateVectorInsertExtractStridedSliceTransforms(patterns);
   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
-  patterns.add<VectorCreateMaskOpRewritePattern>(ctx, indexOptimizations);
+  patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
   patterns
       .add<
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ ... @@
 struct LowerVectorToLLVMPass
     : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
   LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
     this->reassociateFPReductions = options.reassociateFPReductions;
-    this->indexOptimizations = options.indexOptimizations;
+    this->force32BitVectorIndices = options.force32BitVectorIndices;
     this->armNeon = options.armNeon;
     this->armSVE = options.armSVE;
     this->amx = options.amx;
@@ -77,11 +77,11 @@
   // Convert to the LLVM IR dialect.
   LLVMTypeConverter converter(&getContext());
   RewritePatternSet patterns(&getContext());
-  populateVectorMaskMaterializationPatterns(patterns, indexOptimizations);
+  populateVectorMaskMaterializationPatterns(patterns, force32BitVectorIndices);
   populateVectorTransferLoweringPatterns(patterns);
   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
   populateVectorToLLVMConversionPatterns(
-      converter, patterns, reassociateFPReductions, indexOptimizations);
+      converter, patterns, reassociateFPReductions, force32BitVectorIndices);
   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);

   // Architecture specific augmentations.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -2185,22 +2185,22 @@
 // generates more elaborate instructions for this intrinsic since it
 // is very conservative on the boundary conditions.
 static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
-                                   bool indexOptimizations, int64_t dim,
+                                   bool force32BitVectorIndices, int64_t dim,
                                    Value b, Value *off = nullptr) {
   auto loc = op->getLoc();
   // If we can assume all indices fit in 32-bit, we perform the vector
   // comparison in 32-bit to get a higher degree of SIMD parallelism.
   // Otherwise we perform the vector comparison using 64-bit indices.
   Type idxType =
-      indexOptimizations ? rewriter.getI32Type() : rewriter.getI64Type();
+      force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
   DenseIntElementsAttr indicesAttr;
-  if (dim == 0 && indexOptimizations) {
+  if (dim == 0 && force32BitVectorIndices) {
     indicesAttr = DenseIntElementsAttr::get(
         VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
   } else if (dim == 0) {
     indicesAttr = DenseIntElementsAttr::get(
         VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
-  } else if (indexOptimizations) {
+  } else if (force32BitVectorIndices) {
     indicesAttr = rewriter.getI32VectorAttr(
         llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
   } else {
@@ -2227,7 +2227,7 @@
 public:
   explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt)
       : mlir::OpRewritePattern<ConcreteOp>(context),
-        indexOptimizations(enableIndexOpt) {}
+        force32BitVectorIndices(enableIndexOpt) {}

   LogicalResult matchAndRewrite(ConcreteOp xferOp,
                                 PatternRewriter &rewriter) const override {
@@ -2270,7 +2270,7 @@
   }

 private:
-  const bool indexOptimizations;
+  const bool force32BitVectorIndices;
 };

 /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
@@ -2280,7 +2280,7 @@
   explicit VectorCreateMaskOpConversion(MLIRContext *context,
                                         bool enableIndexOpt)
       : mlir::OpRewritePattern<vector::CreateMaskOp>(context),
-        indexOptimizations(enableIndexOpt) {}
+        force32BitVectorIndices(enableIndexOpt) {}

   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                 PatternRewriter &rewriter) const override {
@@ -2291,14 +2291,14 @@
     if (rank > 1)
       return failure();
     rewriter.replaceOp(
-        op, buildVectorComparison(rewriter, op, indexOptimizations,
+        op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
                                   rank == 0 ? 0 : dstType.getDimSize(0),
                                   op.getOperand(0)));
     return success();
   }

 private:
-  const bool indexOptimizations;
+  const bool force32BitVectorIndices;
 };

 // Drop inner most contiguous unit dimensions from transfer_read operand.
@@ -2592,11 +2592,11 @@
 } // namespace

 void mlir::vector::populateVectorMaskMaterializationPatterns(
-    RewritePatternSet &patterns, bool indexOptimizations) {
+    RewritePatternSet &patterns, bool force32BitVectorIndices) {
   patterns.add<VectorCreateMaskOpConversion,
                MaterializeTransferMask<vector::TransferReadOp>,
                MaterializeTransferMask<vector::TransferWriteOp>>(
-      patterns.getContext(), indexOptimizations);
+      patterns.getContext(), force32BitVectorIndices);
 }

 void mlir::vector::populateShapeCastFoldingPatterns(
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s --convert-vector-to-llvm='enable-index-optimizations=1' | FileCheck %s --check-prefix=CMP32
-// RUN: mlir-opt %s --convert-vector-to-llvm='enable-index-optimizations=0' | FileCheck %s --check-prefix=CMP64
+// RUN: mlir-opt %s --convert-vector-to-llvm='force-32bit-vector-indices=1' | FileCheck %s --check-prefix=CMP32
+// RUN: mlir-opt %s --convert-vector-to-llvm='force-32bit-vector-indices=0' | FileCheck %s --check-prefix=CMP64

 // CMP32-LABEL: @genbool_var_1d(
 // CMP32-SAME: %[[ARG:.*]]: index)
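
Note for downstream users: this patch renames the pass option string from `enable-index-optimizations` to `force-32bit-vector-indices` and the C++ option member from `indexOptimizations` to `force32BitVectorIndices`; the fluent setter `enableIndexOptimizations` keeps its old name here and simply writes the renamed member. Below is a minimal sketch of driving the renamed option from a C++ pass pipeline; only `LowerVectorToLLVMOptions` and `createConvertVectorToLLVMPass` come from the files above, while the surrounding helper function and its name are hypothetical:

    #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
    #include "mlir/Pass/PassManager.h"

    // Hypothetical helper: lowers vector ops to LLVM, optionally assuming all
    // vector indices fit in 32 bits (narrower index vectors, more SIMD lanes).
    static void buildVectorLoweringPipeline(mlir::PassManager &pm,
                                            bool narrowIndices) {
      mlir::LowerVectorToLLVMOptions options;
      // The setter keeps its pre-rename name; it now sets the
      // force32BitVectorIndices member introduced by this patch.
      options.enableIndexOptimizations(narrowIndices);
      pm.addPass(mlir::createConvertVectorToLLVMPass(options));
    }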