diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -2131,7 +2131,9 @@ specifies an exclusive upper bound [0, mask-dim-size-element-value) for a unique dimension in the vector result. The conjunction of the ranges define a hyper-rectangular region within which elements values are set to 1 - (otherwise element values are set to 0). + (otherwise element values are set to 0). Each value of 'mask_dim_sizes' must + be non-negative and not greater than the size of the corresponding vector + dimension (as opposed to vector.create_mask which allows this). Example: @@ -2169,7 +2171,9 @@ each operand specifies a range [0, operand-value) for a unique dimension in the vector result. The conjunction of the operand ranges define a hyper-rectangular region within which elements values are set to 1 - (otherwise element values are set to 0). + (otherwise element values are set to 0). If operand-value is negative, it is + treated as if it were zero, and if it is greater than the corresponding + dimension size, it is treated as if it were equal to the dimension size. Example: diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -4235,9 +4235,18 @@ return failure(); // Gather constant mask dimension sizes. SmallVector<int64_t, 4> maskDimSizes; - for (auto operand : createMaskOp.operands()) { - auto *defOp = operand.getDefiningOp(); - maskDimSizes.push_back(cast<arith::ConstantIndexOp>(defOp).value()); + for (auto it : llvm::zip(createMaskOp.operands(), + createMaskOp.getType().getShape())) { + auto *defOp = std::get<0>(it).getDefiningOp(); + int64_t maxDimSize = std::get<1>(it); + int64_t dimSize = cast<arith::ConstantIndexOp>(defOp).value(); + dimSize = std::min(dimSize, maxDimSize); + // Any non-positive dim size yields an all-false mask: zero all dims. 
+ if (dimSize <= 0) { + maskDimSizes.assign(createMaskOp.getType().getRank(), 0); + break; + } + maskDimSizes.push_back(dimSize); } // Replace 'createMaskOp' with ConstantMaskOp. rewriter.replaceOpWithNewOp<ConstantMaskOp>( diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -13,6 +13,39 @@ // ----- +// CHECK-LABEL: create_vector_mask_to_constant_mask_truncation +func @create_vector_mask_to_constant_mask_truncation() -> (vector<4x3xi1>) { + %c2 = arith.constant 2 : index + %c5 = arith.constant 5 : index + // CHECK: vector.constant_mask [4, 2] : vector<4x3xi1> + %0 = vector.create_mask %c5, %c2 : vector<4x3xi1> + return %0 : vector<4x3xi1> +} + +// ----- + +// CHECK-LABEL: create_vector_mask_to_constant_mask_truncation_neg +func @create_vector_mask_to_constant_mask_truncation_neg() -> (vector<4x3xi1>) { + %cneg2 = arith.constant -2 : index + %c5 = arith.constant 5 : index + // CHECK: vector.constant_mask [0, 0] : vector<4x3xi1> + %0 = vector.create_mask %c5, %cneg2 : vector<4x3xi1> + return %0 : vector<4x3xi1> +} + +// ----- + +// CHECK-LABEL: create_vector_mask_to_constant_mask_truncation_zero +func @create_vector_mask_to_constant_mask_truncation_zero() -> (vector<4x3xi1>) { + %c2 = arith.constant 2 : index + %c0 = arith.constant 0 : index + // CHECK: vector.constant_mask [0, 0] : vector<4x3xi1> + %0 = vector.create_mask %c0, %c2 : vector<4x3xi1> + return %0 : vector<4x3xi1> +} + +// ----- + func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) { %0 = vector.constant_mask [2, 2] : vector<4x3xi1> %1 = vector.extract_strided_slice %0 diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-create-mask.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-create-mask.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-create-mask.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-create-mask.mlir @@ -4,11 +4,13 
@@ // RUN: FileCheck %s func @entry() { + %cneg1 = arith.constant -1 : index %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index %c3 = arith.constant 3 : index %c6 = arith.constant 6 : index + %c7 = arith.constant 7 : index // // 1-D. @@ -18,16 +20,18 @@ vector.print %1 : vector<5xi1> // CHECK: ( 1, 1, 0, 0, 0 ) - scf.for %i = %c0 to %c6 step %c1 { + scf.for %i = %cneg1 to %c7 step %c1 { %2 = vector.create_mask %i : vector<5xi1> vector.print %2 : vector<5xi1> } // CHECK: ( 0, 0, 0, 0, 0 ) + // CHECK: ( 0, 0, 0, 0, 0 ) // CHECK: ( 1, 0, 0, 0, 0 ) // CHECK: ( 1, 1, 0, 0, 0 ) // CHECK: ( 1, 1, 1, 0, 0 ) // CHECK: ( 1, 1, 1, 1, 0 ) // CHECK: ( 1, 1, 1, 1, 1 ) + // CHECK: ( 1, 1, 1, 1, 1 ) // // 2-D.