diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -65,7 +65,7 @@
 /// Collect a set of patterns to convert from the Vector dialect to LLVM.
 void populateVectorToLLVMConversionPatterns(
     LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    bool reassociateFPReductions = false);
+    bool reassociateFPReductions = false, bool indexOptimizations = false);
 
 /// Create a pass to convert vector operations to the LLVMIR dialect.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertVectorToLLVMPass(
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1791,6 +1791,14 @@
 /// Create a call to vscale intrinsic.
 def LLVM_vscale : LLVM_IntrOp<"vscale", [0], [], [], 1>;
 
+/// Create a call to stepvector intrinsic.
+def LLVM_StepVectorOp
+    : LLVM_IntrOp<"experimental.stepvector", [0], [], [NoSideEffect], 1> {
+  let arguments = (ins);
+  let results = (outs LLVM_Type:$res);
+  let assemblyFormat = "attr-dict `:` type($res)";
+}
+
 // Atomic operations.
 //
 
diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -91,6 +91,12 @@
 /// optimizations.
 void transferOpflowOpt(FuncOp func);
 
+/// Create a cast from an index-like value (index or integer) to another
+/// index-like value. If the value type and the target type are the same, it
+/// returns the original value.
+Value createCastToIndexLike(PatternRewriter &rewriter, Location loc,
+                            Type targetType, Value value);
+
 } // namespace vector
 } // namespace mlir
 
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -901,6 +901,42 @@
   }
 };
 
+/// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
+/// Non-scalable versions of this operation are handled in Vector Transforms.
+class VectorCreateMaskOpRewritePattern
+    : public OpRewritePattern<vector::CreateMaskOp> {
+public:
+  explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
+                                            bool enableIndexOpt)
+      : mlir::OpRewritePattern<vector::CreateMaskOp>(context),
+        indexOptimizations(enableIndexOpt) {}
+
+  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
+                                PatternRewriter &rewriter) const override {
+    auto dstType = op.getType();
+    if (dstType.getRank() != 1 || !dstType.cast<VectorType>().isScalable())
+      return failure();
+    IntegerType idxType =
+        indexOptimizations ? rewriter.getI32Type() : rewriter.getI64Type();
+    auto loc = op->getLoc();
+    // Lane i of the mask is set iff i < bound (signed compare): the step
+    // vector <0, 1, 2, ...> is compared lane-wise against a splat of bound.
+    Value indices = rewriter.create<LLVM::StepVectorOp>(
+        loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
+                                 /*isScalable=*/true));
+    auto bound =
+        createCastToIndexLike(rewriter, loc, idxType, op.getOperand(0));
+    Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
+    Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
+                                                indices, bounds);
+    rewriter.replaceOp(op, comp);
+    return success();
+  }
+
+private:
+  const bool indexOptimizations;
+};
+
 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
 public:
   using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
@@ -1065,13 +1101,15 @@
 } // namespace
 
 /// Populate the given list with patterns that convert from Vector to LLVM.
-void mlir::populateVectorToLLVMConversionPatterns(
-    LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    bool reassociateFPReductions) {
+void mlir::populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                                  RewritePatternSet &patterns,
+                                                  bool reassociateFPReductions,
+                                                  bool indexOptimizations) {
   MLIRContext *ctx = converter.getDialect()->getContext();
   patterns.add<VectorFMAOpNDRewritePattern>(ctx);
   populateVectorInsertExtractStridedSliceTransforms(patterns);
   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
+  patterns.add<VectorCreateMaskOpRewritePattern>(ctx, indexOptimizations);
   patterns
# NOTE(review): this span was lost when the patch was extracted and must be
# restored from the original patch before applying. Missing: the last two
# context lines of the hunk above (per its -1065,13 header), any intervening
# file diffs, and the diff/---/+++/@@ headers plus leading lines of the next
# hunk (presumably the CreateMaskOp folder in
# mlir/lib/Dialect/Vector/VectorOps.cpp). The only surviving text of that
# hunk is the tail `().isScalable())` of an added guard line, then the lines
# below.
+      return failure();
     // Gather constant mask dimension sizes.
     SmallVector<int64_t, 4> maskDimSizes;
     for (auto it : llvm::zip(createMaskOp.operands(),
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2111,8 +2111,9 @@
   }
 };
 
-static Value createCastToIndexLike(PatternRewriter &rewriter, Location loc,
-                                   Type targetType, Value value) {
+Value mlir::vector::createCastToIndexLike(PatternRewriter &rewriter,
+                                          Location loc, Type targetType,
+                                          Value value) {
   if (targetType == value.getType())
     return value;
 
@@ -2242,6 +2243,8 @@
   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                 PatternRewriter &rewriter) const override {
     auto dstType = op.getType();
+    if (dstType.cast<VectorType>().isScalable())
+      return failure();
     int64_t rank = dstType.getRank();
     if (rank > 1)
       return failure();
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
@@ -22,6 +22,27 @@
   return %0 : vector<11xi1>
 }
 
+// CMP32-LABEL: @genbool_var_1d_scalable(
+// CMP32-SAME: %[[ARG:.*]]: index)
+// CMP32: %[[T0:.*]] = llvm.intr.experimental.stepvector : vector<[11]xi32>
+// CMP32: %[[T1:.*]] = arith.index_cast %[[ARG]] : index to i32
+// CMP32: %[[T2:.*]] = splat %[[T1]] : vector<[11]xi32>
+// CMP32: %[[T3:.*]] = arith.cmpi slt, %[[T0]], %[[T2]] : vector<[11]xi32>
+// CMP32: return %[[T3]] : vector<[11]xi1>
+
+// CMP64-LABEL: @genbool_var_1d_scalable(
+// CMP64-SAME: %[[ARG:.*]]: index)
+// CMP64: %[[T0:.*]] = llvm.intr.experimental.stepvector : vector<[11]xi64>
+// CMP64: %[[T1:.*]] = arith.index_cast %[[ARG]] : index to i64
+// CMP64: %[[T2:.*]] = splat %[[T1]] : vector<[11]xi64>
+// CMP64: %[[T3:.*]] = arith.cmpi slt, %[[T0]], %[[T2]] : vector<[11]xi64>
+// CMP64: return %[[T3]] : vector<[11]xi1>
+
+func @genbool_var_1d_scalable(%arg0: index) -> vector<[11]xi1> {
+  %0 = vector.create_mask %arg0 : vector<[11]xi1>
+  return %0 : vector<[11]xi1>
+}
+
 // CMP32-LABEL: @transfer_read_1d
 // CMP32: %[[C:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
 // CMP32: %[[A:.*]] = arith.addi %{{.*}}, %[[C]] : vector<16xi32>
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1469,6 +1469,19 @@
 // CHECK: %[[result:.*]] = arith.cmpi slt, %[[indices]], %[[bounds]] : vector<4xi32>
 // CHECK: return %[[result]] : vector<4xi1>
 
+func @create_mask_1d_scalable(%a : index) -> vector<[4]xi1> {
+  %v = vector.create_mask %a : vector<[4]xi1>
+  return %v : vector<[4]xi1>
+}
+
+// CHECK-LABEL: func @create_mask_1d_scalable
+// CHECK-SAME: %[[arg:.*]]: index
+// CHECK: %[[indices:.*]] = llvm.intr.experimental.stepvector : vector<[4]xi32>
+// CHECK: %[[arg_i32:.*]] = arith.index_cast %[[arg]] : index to i32
+// CHECK: %[[bounds:.*]] = splat %[[arg_i32]] : vector<[4]xi32>
+// CHECK: %[[result:.*]] = arith.cmpi slt, %[[indices]], %[[bounds]] : vector<[4]xi32>
+// CHECK: return %[[result]] : vector<[4]xi1>
+
 // -----
 
 func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {