Support 0-D vectors in vector.create_mask.

Commentary (everything before the first "diff --git" line is not part of any
hunk):
  * VectorOps.td: the result constraint of vector.create_mask becomes
    VectorOfAnyRankOf<[I1]>.
  * VectorOps.cpp: the verifier requires exactly one operand when the result
    has rank 0, and otherwise one operand per result dimension.
  * VectorTransforms.cpp: the CreateMaskOp rewrite in the first hunk reports a
    match failure for rank <= 1; the mask-comparison helper builds a 0-D index
    constant when dim == 0; VectorCreateMaskOpConversion accepts rank 0 and
    rank 1 and passes dim = 0 for rank 0.
  * Tests: 0-D and 1-D lowering checks, two verifier error cases, and a CPU
    integration test printing the 0-D masks for bounds 0 and 1.

NOTE(review): angle-bracket arguments (Variadic<Index>, cast<VectorType>,
create<arith::ConstantOp>, ArrayRef<int32_t>/<int64_t>, seq<int32_t>/<int64_t>,
OpRewritePattern<vector::CreateMaskOp>, vector<i1>/<i32>/<f32>) and the leading
whitespace of context lines were reconstructed; confirm against the target
tree before applying.

diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -2149,7 +2149,8 @@
 
 def Vector_CreateMaskOp :
   Vector_Op<"create_mask", [NoSideEffect]>,
-    Arguments<(ins Variadic<Index>:$operands)>, Results<(outs VectorOf<[I1]>)> {
+    Arguments<(ins Variadic<Index>:$operands)>,
+    Results<(outs VectorOfAnyRankOf<[I1]>)> {
   let summary = "creates a vector mask";
   let description = [{
     Creates and returns a vector mask where elements of the result vector
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -3968,11 +3968,16 @@
 //===----------------------------------------------------------------------===//
 
 static LogicalResult verify(CreateMaskOp op) {
+  auto vectorType = op.getResult().getType().cast<VectorType>();
   // Verify that an operand was specified for each result vector each dimension.
-  if (op.getNumOperands() !=
-      op.getResult().getType().cast<VectorType>().getRank())
+  if (vectorType.getRank() == 0) {
+    if (op->getNumOperands() != 1)
+      return op.emitOpError(
+          "must specify exactly one operand for 0-D create_mask");
+  } else if (op.getNumOperands() != vectorType.getRank()) {
     return op.emitOpError(
         "must specify an operand for each result vector dimension");
+  }
   return success();
 }
 
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -677,16 +677,17 @@
 
   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                 PatternRewriter &rewriter) const override {
-    auto loc = op.getLoc();
     auto dstType = op.getResult().getType().cast<VectorType>();
+    int64_t rank = dstType.getRank();
+    if (rank <= 1)
+      return rewriter.notifyMatchFailure(
+          op, "0-D and 1-D vectors are handled separately");
+
+    auto loc = op.getLoc();
     auto eltType = dstType.getElementType();
     int64_t dim = dstType.getDimSize(0);
-    int64_t rank = dstType.getRank();
     Value idx = op.getOperand(0);
 
-    if (rank == 1)
-      return failure(); // leave for lowering
-
     VectorType lowType =
         VectorType::get(dstType.getShape().drop_front(), eltType);
     Value trueVal = rewriter.create<arith::ConstantOp>(
@@ -2717,6 +2718,8 @@
 // Helper that returns a vector comparison that constructs a mask:
 //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
 //
+// If `dim == 0` then the result will be a 0-D vector.
+//
 // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
 //       much more compact, IR for this operation, but LLVM eventually
 //       generates more elaborate instructions for this intrinsic since it
@@ -2728,19 +2731,23 @@
   // If we can assume all indices fit in 32-bit, we perform the vector
   // comparison in 32-bit to get a higher degree of SIMD parallelism.
   // Otherwise we perform the vector comparison using 64-bit indices.
-  Value indices;
-  Type idxType;
-  if (indexOptimizations) {
-    indices = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32VectorAttr(
-                 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
-    idxType = rewriter.getI32Type();
+  Type idxType =
+      indexOptimizations ? rewriter.getI32Type() : rewriter.getI64Type();
+  DenseIntElementsAttr indicesAttr;
+  if (dim == 0 && indexOptimizations) {
+    indicesAttr = DenseIntElementsAttr::get(
+        VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
+  } else if (dim == 0) {
+    indicesAttr = DenseIntElementsAttr::get(
+        VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
+  } else if (indexOptimizations) {
+    indicesAttr = rewriter.getI32VectorAttr(
+        llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
   } else {
-    indices = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI64VectorAttr(
-                 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
-    idxType = rewriter.getI64Type();
+    indicesAttr = rewriter.getI64VectorAttr(
+        llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
   }
+  Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
   // Add in an offset if requested.
   if (off) {
     Value o = createCastToIndexLike(rewriter, loc, idxType, *off);
@@ -2806,7 +2813,7 @@
   const bool indexOptimizations;
 };
 
-/// Conversion pattern for a vector.create_mask (1-D only).
+/// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
 class VectorCreateMaskOpConversion
     : public OpRewritePattern<vector::CreateMaskOp> {
 public:
@@ -2819,13 +2826,13 @@
                                 PatternRewriter &rewriter) const override {
     auto dstType = op.getType();
     int64_t rank = dstType.getRank();
-    if (rank == 1) {
-      rewriter.replaceOp(
-          op, buildVectorComparison(rewriter, op, indexOptimizations,
-                                    dstType.getDimSize(0), op.getOperand(0)));
-      return success();
-    }
-    return failure();
+    if (rank > 1)
+      return failure();
+    rewriter.replaceOp(
+        op, buildVectorComparison(rewriter, op, indexOptimizations,
+                                  rank == 0 ? 0 : dstType.getDimSize(0),
+                                  op.getOperand(0)));
+    return success();
   }
 
 private:
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1442,6 +1442,36 @@
 
 // -----
 
+func @create_mask_0d(%a : index) -> vector<i1> {
+  %v = vector.create_mask %a : vector<i1>
+  return %v: vector<i1>
+}
+
+// CHECK-LABEL: func @create_mask_0d
+// CHECK-SAME: %[[arg:.*]]: index
+// CHECK: %[[indices:.*]] = arith.constant dense<0> : vector<i32>
+// CHECK: %[[arg_i32:.*]] = arith.index_cast %[[arg]] : index to i32
+// CHECK: %[[bounds:.*]] = splat %[[arg_i32]] : vector<i32>
+// CHECK: %[[result:.*]] = arith.cmpi slt, %[[indices]], %[[bounds]] : vector<i32>
+// CHECK: return %[[result]] : vector<i1>
+
+// -----
+
+func @create_mask_1d(%a : index) -> vector<4xi1> {
+  %v = vector.create_mask %a : vector<4xi1>
+  return %v: vector<4xi1>
+}
+
+// CHECK-LABEL: func @create_mask_1d
+// CHECK-SAME: %[[arg:.*]]: index
+// CHECK: %[[indices:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
+// CHECK: %[[arg_i32:.*]] = arith.index_cast %[[arg]] : index to i32
+// CHECK: %[[bounds:.*]] = splat %[[arg_i32]] : vector<4xi32>
+// CHECK: %[[result:.*]] = arith.cmpi slt, %[[indices]], %[[bounds]] : vector<4xi32>
+// CHECK: return %[[result]] : vector<4xi1>
+
+// -----
+
 func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
   %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 }
      : vector<16xf32> -> vector<16xf32>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -874,6 +874,24 @@
 
 // -----
 
+func @create_mask_0d_no_operands() {
+  %c1 = arith.constant 1 : index
+  // expected-error@+1 {{must specify exactly one operand for 0-D create_mask}}
+  %0 = vector.create_mask : vector<i1>
+}
+
+// -----
+
+func @create_mask_0d_many_operands() {
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  // expected-error@+1 {{must specify exactly one operand for 0-D create_mask}}
+  %0 = vector.create_mask %c1, %c2, %c3 : vector<i1>
+}
+
+// -----
+
 func @create_mask() {
   %c2 = arith.constant 2 : index
   %c3 = arith.constant 3 : index
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
@@ -93,6 +93,18 @@
   return
 }
 
+func @create_mask_0d(%zero : index, %one : index) {
+  %zero_mask = vector.create_mask %zero : vector<i1>
+  // CHECK: ( 0 )
+  vector.print %zero_mask : vector<i1>
+
+  %one_mask = vector.create_mask %one : vector<i1>
+  // CHECK: ( 1 )
+  vector.print %one_mask : vector<i1>
+
+  return
+}
+
 func @entry() {
   %0 = arith.constant 42.0 : f32
   %1 = arith.constant dense<0.0> : vector<f32>
@@ -115,5 +127,9 @@
   %bigger = arith.constant dense<4242> : vector<i32>
   call @arith_cmpi_0d(%smaller, %bigger) : (vector<i32>, vector<i32>) -> ()
 
+  %zero_idx = arith.constant 0 : index
+  %one_idx = arith.constant 1 : index
+  call @create_mask_0d(%zero_idx, %one_idx) : (index, index) -> ()
+
   return
 }