// Review notes (prepended commentary: `git apply` and `patch` skip any text
// that precedes the first `diff --git` header).
//
// What the patch does:
//   - Adds VectorCreateMaskOpRewritePattern to the Vector-to-LLVM patterns. It
//     matches only rank-1 scalable `vector.create_mask` ops and rewrites them
//     as `llvm.intr.experimental.stepvector`, a broadcast of the bound (seen as
//     insertelement + shufflevector in the lit tests), and `arith.cmpi slt`.
//   - Adds an `indexOptimizations` parameter (default false) to
//     populateVectorToLLVMConversionPatterns; the new pattern uses i32 indices
//     when it is set and i64 otherwise.
//   - Defines LLVM_StepVectorOp for the `experimental.stepvector` intrinsic.
//   - Replaces the file-local createCastToIndexLike in VectorTransforms.cpp
//     with getValueOrCreateCastToIndexLike, declared in
//     mlir/Dialect/Arithmetic/Utils/Utils.h, and updates its callers.
//   - Makes the vector.constant_mask verifier reject non-zero mask dim sizes
//     on scalable result types, and lowers scalable constant masks to an
//     all-false constant.
//   - Stops the create_mask -> constant_mask fold from firing on scalable
//     types when any constant operand is strictly positive.
//   - Makes the existing vector::CreateMaskOp lowering in VectorTransforms.cpp
//     return failure for scalable result types.
//   - Adds lit tests for the cases above.
//
// NOTE(review): this copy of the patch is damaged and will not apply:
//   - Line breaks inside the hunks were collapsed, so each physical line holds
//     many patch lines and the original indentation is gone.
//   - Angle-bracketed text appears to have been stripped, e.g. template
//     arguments (`OpRewritePattern`, `rewriter.create(`, `dyn_cast()`,
//     `std::unique_ptr>`), `memref` element types and `dense` payloads in
//     the lit tests.
//   - The same stripping swallowed a whole span: the text runs from
//     `patterns .add` in the populate hunk straight into
//     `(loc, attr.getValue().getSExtValue()); }`, losing the rest of that hunk
//     and the `diff --git` header of the file that defines
//     getValueOrCreateCastToIndexLike (presumably
//     lib/Dialect/Arithmetic/Utils/Utils.cpp).
//   Regenerate the patch from the source commit instead of hand-repairing it.
//
// NOTE(review) on the change itself:
//   - No pattern in this patch lowers an N-D scalable create_mask: the
//     VectorTransforms lowering bails on any scalable type and the new LLVM
//     pattern bails unless rank == 1. Consider a test that pins this down.
//   - The scalable check in the constant_mask verifier reads only
//     mask_dim_sizes()[0]; it rejects every non-all-zero mask only because
//     the check just before it already requires all sizes to be zero when any
//     one is zero.
//   - canonicalize.mlir covers only the fold-to-[0] case; nothing checks that
//     a scalable create_mask with a strictly positive constant stays unfolded.
//   - VectorTransforms.cpp gains an include of LLVMIR/LLVMDialect.h, but none
//     of its hunks shown here use the LLVM dialect; confirm it is needed.
//   - The signedness assert in getValueOrCreateCastToIndexLike has no message
//     string, unlike the assert immediately before it.
//   - In the fold, if createMaskOp.getType() already returns VectorType the
//     dyn_cast guard is redundant; confirm against the op definition.
diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h --- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h +++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h @@ -63,9 +63,10 @@ RewritePatternSet &patterns); /// Collect a set of patterns to convert from the Vector dialect to LLVM. +/// If `indexOptimizations` is set, assume indices fit in 32-bit. void populateVectorToLLVMConversionPatterns( LLVMTypeConverter &converter, RewritePatternSet &patterns, - bool reassociateFPReductions = false); + bool reassociateFPReductions = false, bool indexOptimizations = false); /// Create a pass to convert vector operations to the LLVMIR dialect. std::unique_ptr> createConvertVectorToLLVMPass( diff --git a/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h b/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h @@ -80,6 +80,12 @@ Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr); +/// Create a cast from an index-like value (index or integer) to another +/// index-like value. If the value type and the target type are the same, it +/// returns the original value. +Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, + Type targetType, Value value); + /// Similar to the other overload, but converts multiple OpFoldResults into /// Values. SmallVector diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -1752,6 +1752,14 @@ /// Create a call to vscale intrinsic. def LLVM_vscale : LLVM_IntrOp<"vscale", [0], [], [], 1>; +/// Create a call to stepvector intrinsic.
+def LLVM_StepVectorOp + : LLVM_IntrOp<"experimental.stepvector", [0], [], [NoSideEffect], 1> { + let arguments = (ins); + let results = (outs LLVM_Type:$res); + let assemblyFormat = "attr-dict `:` type($res)"; +} + // Atomic operations. // diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -10,6 +10,7 @@ #include "mlir/Conversion/LLVMCommon/VectorPattern.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Arithmetic/Utils/Utils.h" #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -900,6 +901,40 @@ } }; +/// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only). +/// Non-scalable versions of this operation are handled in Vector Transforms. +class VectorCreateMaskOpRewritePattern + : public OpRewritePattern { +public: + explicit VectorCreateMaskOpRewritePattern(MLIRContext *context, + bool enableIndexOpt) + : OpRewritePattern(context), + indexOptimizations(enableIndexOpt) {} + + LogicalResult matchAndRewrite(vector::CreateMaskOp op, + PatternRewriter &rewriter) const override { + auto dstType = op.getType(); + if (dstType.getRank() != 1 || !dstType.cast().isScalable()) + return failure(); + IntegerType idxType = + indexOptimizations ?
rewriter.getI32Type() : rewriter.getI64Type(); + auto loc = op->getLoc(); + Value indices = rewriter.create( + loc, LLVM::getVectorType(idxType, dstType.getShape()[0], + /*isScalable=*/true)); + auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, + op.getOperand(0)); + Value bounds = rewriter.create(loc, indices.getType(), bound); + Value comp = rewriter.create(loc, arith::CmpIPredicate::slt, + indices, bounds); + rewriter.replaceOp(op, comp); + return success(); + } + +private: + const bool indexOptimizations; +}; + class VectorPrintOpConversion : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -1157,13 +1192,15 @@ } // namespace /// Populate the given list with patterns that convert from Vector to LLVM. -void mlir::populateVectorToLLVMConversionPatterns( - LLVMTypeConverter &converter, RewritePatternSet &patterns, - bool reassociateFPReductions) { +void mlir::populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter, + RewritePatternSet &patterns, + bool reassociateFPReductions, + bool indexOptimizations) { MLIRContext *ctx = converter.getDialect()->getContext(); patterns.add(ctx); populateVectorInsertExtractStridedSliceTransforms(patterns); patterns.add(converter, reassociateFPReductions); + patterns.add(ctx, indexOptimizations); patterns .add(loc, attr.getValue().getSExtValue()); } +Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, + Type targetType, Value value) { + if (targetType == value.getType()) + return value; + + bool targetIsIndex = targetType.isIndex(); + bool valueIsIndex = value.getType().isIndex(); + if (targetIsIndex ^ valueIsIndex) + return b.create(loc, targetType, value); + + auto targetIntegerType = targetType.dyn_cast(); + auto valueIntegerType = value.getType().dyn_cast(); + assert(targetIntegerType && valueIntegerType && + "unexpected cast between types other than integers and index"); + assert(targetIntegerType.getSignedness() ==
valueIntegerType.getSignedness()); + + if (targetIntegerType.getWidth() > valueIntegerType.getWidth()) + return b.create(loc, targetIntegerType, value); + return b.create(loc, targetIntegerType, value); +} + SmallVector mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, ArrayRef valueOrAttrVec) { diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -4232,6 +4232,14 @@ if (anyZeros && !allZeros) return emitOpError("expected all mask dim sizes to be zeros, " "as a result of conjunction with zero mask dim"); + // Verify that if the mask type is scalable, dimensions should be zero because + // constant scalable masks can only be defined for the "none set" or "all set" + // cases, and there is no VLA way to define an "all set" case for + // `vector.constant_mask`. In the future, a convention could be established + // to decide if a specific dimension value could be considered as "all set". + if (resultType.isScalable() && + mask_dim_sizes()[0].cast().getInt() != 0) + return emitOpError("expected mask dim sizes for scalable masks to be 0"); return success(); } @@ -4269,6 +4277,19 @@ }; if (llvm::any_of(createMaskOp.operands(), isNotDefByConstant)) return failure(); + + // CreateMaskOp for scalable vectors can be folded only if all dimensions + // are negative or zero. + if (auto vType = createMaskOp.getType().dyn_cast()) { + if (vType.isScalable()) + for (auto opDim : createMaskOp.getOperands()) { + APInt intVal; + if (matchPattern(opDim, m_ConstantInt(&intVal)) && + intVal.isStrictlyPositive()) + return failure(); + } + } + // Gather constant mask dimension sizes.
SmallVector maskDimSizes; for (auto it : llvm::zip(createMaskOp.operands(), diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -16,6 +16,8 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Arithmetic/Utils/Utils.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" @@ -602,6 +604,13 @@ return success(); } + // Scalable constant masks can only be lowered for the "none set" case. + if (dstType.cast().isScalable()) { + rewriter.replaceOpWithNewOp( + op, DenseElementsAttr::get(dstType, false)); + return success(); + } + int64_t trueDim = std::min(dstType.getDimSize(0), dimSizes[0].cast().getInt()); @@ -2161,27 +2170,6 @@ } }; -static Value createCastToIndexLike(PatternRewriter &rewriter, Location loc, - Type targetType, Value value) { - if (targetType == value.getType()) - return value; - - bool targetIsIndex = targetType.isIndex(); - bool valueIsIndex = value.getType().isIndex(); - if (targetIsIndex ^ valueIsIndex) - return rewriter.create(loc, targetType, value); - - auto targetIntegerType = targetType.dyn_cast(); - auto valueIntegerType = value.getType().dyn_cast(); - assert(targetIntegerType && valueIntegerType && - "unexpected cast between types other than integers and index"); - assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness()); - - if (targetIntegerType.getWidth() > valueIntegerType.getWidth()) - return rewriter.create(loc, targetIntegerType, value); - return rewriter.create(loc, targetIntegerType, value); -} - // Helper that returns a vector comparison that constructs a mask: // mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b] // @@ -2217,12
@@ Value indices = rewriter.create(loc, indicesAttr); // Add in an offset if requested. if (off) { - Value o = createCastToIndexLike(rewriter, loc, idxType, *off); + Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off); Value ov = rewriter.create(loc, indices.getType(), o); indices = rewriter.create(loc, ov, indices); } // Construct the vector comparison. - Value bound = createCastToIndexLike(rewriter, loc, idxType, b); + Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b); Value bounds = rewriter.create(loc, indices.getType(), bound); return rewriter.create(loc, arith::CmpIPredicate::slt, indices, @@ -2292,6 +2280,8 @@ LogicalResult matchAndRewrite(vector::CreateMaskOp op, PatternRewriter &rewriter) const override { auto dstType = op.getType(); + if (dstType.cast().isScalable()) + return failure(); int64_t rank = dstType.getRank(); if (rank > 1) return failure(); diff --git a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir @@ -24,6 +24,29 @@ return %0 : vector<11xi1> } +// CMP32-LABEL: @genbool_var_1d_scalable( +// CMP32-SAME: %[[ARG:.*]]: index) +// CMP32: %[[T0:.*]] = llvm.intr.experimental.stepvector : vector<[11]xi32> +// CMP32: %[[T1:.*]] = arith.index_cast %[[ARG]] : index to i32 +// CMP32: %[[T2:.*]] = llvm.insertelement %[[T1]], %{{.*}}[%{{.*}} : i32] : vector<[11]xi32> +// CMP32: %[[T3:.*]] = llvm.shufflevector %[[T2]], %{{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<[11]xi32>, vector<[11]xi32> +// CMP32: %[[T4:.*]] = arith.cmpi slt, %[[T0]], %[[T3]] : vector<[11]xi32> +// CMP32: return %[[T4]] : vector<[11]xi1> + +// CMP64-LABEL: @genbool_var_1d_scalable( +// CMP64-SAME: %[[ARG:.*]]: index) +// CMP64: %[[T0:.*]] = llvm.intr.experimental.stepvector :
vector<[11]xi64> +// CMP64: %[[T1:.*]] = arith.index_cast %[[ARG]] : index to i64 +// CMP64: %[[T2:.*]] = llvm.insertelement %[[T1]], %{{.*}}[%{{.*}} : i32] : vector<[11]xi64> +// CMP64: %[[T3:.*]] = llvm.shufflevector %[[T2]], %{{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<[11]xi64>, vector<[11]xi64> +// CMP64: %[[T4:.*]] = arith.cmpi slt, %[[T0]], %[[T3]] : vector<[11]xi64> +// CMP64: return %[[T4]] : vector<[11]xi1> + +func @genbool_var_1d_scalable(%arg0: index) -> vector<[11]xi1> { + %0 = vector.create_mask %arg0 : vector<[11]xi1> + return %0 : vector<[11]xi1> +} + // CMP32-LABEL: @transfer_read_1d // CMP32: %[[MEM:.*]]: memref, %[[OFF:.*]]: index) -> vector<16xf32> { // CMP32: %[[D:.*]] = memref.dim %[[MEM]], %{{.*}} : memref diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1459,6 +1459,16 @@ // ----- +func @genbool_1d_scalable() -> vector<[8]xi1> { + %0 = vector.constant_mask [0] : vector<[8]xi1> + return %0 : vector<[8]xi1> +} +// CHECK-LABEL: func @genbool_1d_scalable +// CHECK: %[[VAL_0:.*]] = arith.constant dense : vector<[8]xi1> +// CHECK: return %[[VAL_0]] : vector<[8]xi1> + +// ----- + func @genbool_2d() -> vector<4x4xi1> { %v = vector.constant_mask [2, 2] : vector<4x4xi1> return %v: vector<4x4xi1> @@ -1505,6 +1515,20 @@ // CHECK: %[[result:.*]] = arith.cmpi slt, %[[indices]], %[[bounds]] : vector<4xi32> // CHECK: return %[[result]] : vector<4xi1> +func @create_mask_1d_scalable(%a : index) -> vector<[4]xi1> { + %v = vector.create_mask %a : vector<[4]xi1> + return %v: vector<[4]xi1> +} + +// CHECK-LABEL: func @create_mask_1d_scalable +// CHECK-SAME: %[[arg:.*]]: index +// CHECK: %[[indices:.*]] = llvm.intr.experimental.stepvector : vector<[4]xi32> +// CHECK: %[[arg_i32:.*]] =
arith.index_cast %[[arg]] : index to i32 +// CHECK: %[[boundsInsert:.*]] = llvm.insertelement %[[arg_i32]], {{.*}} : vector<[4]xi32> +// CHECK: %[[bounds:.*]] = llvm.shufflevector %[[boundsInsert]], {{.*}} : vector<[4]xi32>, vector<[4]xi32> +// CHECK: %[[result:.*]] = arith.cmpi slt, %[[indices]], %[[bounds]] : vector<[4]xi32> +// CHECK: return %[[result]] : vector<[4]xi1> + // ----- func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> { diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -13,6 +13,16 @@ // ----- +// CHECK-LABEL: create_scalable_vector_mask_to_constant_mask +func @create_scalable_vector_mask_to_constant_mask() -> (vector<[8]xi1>) { + %c-1 = arith.constant -1 : index + // CHECK: vector.constant_mask [0] : vector<[8]xi1> + %0 = vector.create_mask %c-1 : vector<[8]xi1> + return %0 : vector<[8]xi1> +} + +// ----- + // CHECK-LABEL: create_vector_mask_to_constant_mask_truncation func @create_vector_mask_to_constant_mask_truncation() -> (vector<4x3xi1>) { %c2 = arith.constant 2 : index diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -944,6 +944,13 @@ // ----- +func @constant_mask_scalable_non_zero_dim_size() { + // expected-error@+1 {{expected mask dim sizes for scalable masks to be 0}} + %0 = vector.constant_mask [2] : vector<[8]xi1> +} + +// ----- + func @print_no_result(%arg0 : f32) -> i32 { // expected-error@+1 {{cannot name an operation with no results}} %0 = vector.print %arg0 : f32 diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -389,6 +389,8 @@ func @constant_vector_mask() { // CHECK: vector.constant_mask [3, 2] : vector<4x3xi1> %0 =
vector.constant_mask [3, 2] : vector<4x3xi1> + // CHECK: vector.constant_mask [0] : vector<[4]xi1> + %1 = vector.constant_mask [0] : vector<[4]xi1> return }