diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -2111,7 +2111,7 @@ def Vector_ConstantMaskOp : Vector_Op<"constant_mask", [NoSideEffect]>, Arguments<(ins I64ArrayAttr:$mask_dim_sizes)>, - Results<(outs VectorOf<[I1]>)> { + Results<(outs VectorOfAnyRankOf<[I1]>)> { let summary = "creates a constant vector mask"; let description = [{ Creates and returns a vector mask where elements of the result vector diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -3924,8 +3924,19 @@ //===----------------------------------------------------------------------===// static LogicalResult verify(ConstantMaskOp &op) { - // Verify that array attr size matches the rank of the vector result. auto resultType = op.getResult().getType().cast(); + // Check the corner case of 0-D vectors first. + if (resultType.getRank() == 0) { + if (op.mask_dim_sizes().size() != 1) + return op.emitOpError("array attr must have length 1 for 0-D vectors"); + auto dim = op.mask_dim_sizes()[0].cast().getInt(); + if (dim != 0 && dim != 1) + return op.emitOpError( + "mask dim size must be either 0 or 1 for 0-D vectors"); + return success(); + } + + // Verify that array attr size matches the rank of the vector result.
if (static_cast(op.mask_dim_sizes().size()) != resultType.getRank()) return op.emitOpError( "must specify array attr of size equal vector result rank"); diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -960,7 +960,20 @@ auto dstType = op.getType(); auto eltType = dstType.getElementType(); auto dimSizes = op.mask_dim_sizes(); - int64_t rank = dimSizes.size(); + int64_t rank = dstType.getRank(); + + if (rank == 0) { + assert(dimSizes.size() == 1 && + "Expected exactly one dim size for a 0-D vector"); + bool value = dimSizes[0].cast().getInt() == 1; + rewriter.replaceOpWithNewOp( + op, dstType, + DenseIntElementsAttr::get( + VectorType::get(ArrayRef{}, rewriter.getI1Type()), + ArrayRef{value})); + return success(); + } + int64_t trueDim = std::min(dstType.getDimSize(0), dimSizes[0].cast().getInt()); diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1396,6 +1396,26 @@ // ----- +func @genbool_0d_f() -> vector { + %0 = vector.constant_mask [0] : vector + return %0 : vector +} +// CHECK-LABEL: func @genbool_0d_f +// CHECK: %[[VAL_0:.*]] = arith.constant dense : vector +// CHECK: return %[[VAL_0]] : vector + +// ----- + +func @genbool_0d_t() -> vector { + %0 = vector.constant_mask [1] : vector + return %0 : vector +} +// CHECK-LABEL: func @genbool_0d_t +// CHECK: %[[VAL_0:.*]] = arith.constant dense : vector +// CHECK: return %[[VAL_0]] : vector + +// ----- + func @genbool_1d() -> vector<8xi1> { %0 = vector.constant_mask [4] : vector<8xi1> return %0 : vector<8xi1> diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++
b/mlir/test/Dialect/Vector/invalid.mlir @@ -882,6 +882,20 @@ } +// ----- + +func @constant_mask_0d_no_attr() { + // expected-error@+1 {{array attr must have length 1 for 0-D vectors}} + %0 = vector.constant_mask [] : vector +} + +// ----- + +func @constant_mask_0d_bad_attr() { + // expected-error@+1 {{mask dim size must be either 0 or 1 for 0-D vectors}} + %0 = vector.constant_mask [2] : vector +} + // ----- func @constant_mask() { diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -376,6 +376,15 @@ return } +// CHECK-LABEL: @constant_vector_mask_0d +func @constant_vector_mask_0d() { + // CHECK: vector.constant_mask [0] : vector + %0 = vector.constant_mask [0] : vector + // CHECK: vector.constant_mask [1] : vector + %1 = vector.constant_mask [1] : vector + return +} + // CHECK-LABEL: @constant_vector_mask func @constant_vector_mask() { // CHECK: vector.constant_mask [3, 2] : vector<4x3xi1> diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir @@ -68,6 +68,16 @@ } +func @constant_mask_0d() { + %1 = vector.constant_mask [0] : vector + // CHECK: ( 0 ) + vector.print %1: vector + %2 = vector.constant_mask [1] : vector + // CHECK: ( 1 ) + vector.print %2: vector + return +} + func @entry() { %0 = arith.constant 42.0 : f32 %1 = arith.constant dense<0.0> : vector @@ -78,10 +88,14 @@ call @print_vector_0d(%3) : (vector) -> () %4 = arith.constant 42.0 : f32 + + // Warning: these must be called in their textual order of definition in the + // file to not mess up FileCheck. call @splat_0d(%4) : (f32) -> () call @broadcast_0d(%4) : (f32) -> () call @bitcast_0d() : () -> () + call @constant_mask_0d() : () -> () return }