diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -381,7 +381,8 @@
     if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
       if (mask) {
         VectorType newMaskType =
-            VectorType::get(dstVecType.getShape(), rewriter.getI1Type());
+            VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
+                            dstVecType.getScalableDims());
         mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask);
       }
       cast = rewriter.create<vector::ShapeCastOp>(
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -8,7 +8,6 @@
   %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
   return %0 : vector<4x3xi1>
 }
-
 // -----
 
 // CHECK-LABEL: create_scalable_vector_mask_to_constant_mask
@@ -1320,6 +1319,24 @@
   return %0 : vector<5x4x20xf32>
 }
 
+// -----
+// CHECK-LABEL: func.func @vector_multi_reduction_scalable(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<1x[4]x1xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<1x[4]xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: vector<1x[4]x1xi1>)
+func.func @vector_multi_reduction_scalable(%source: vector<1x[4]x1xf32>,
+                                           %acc: vector<1x[4]xf32>,
+                                           %mask: vector<1x[4]x1xi1>) -> vector<1x[4]xf32> {
+// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_2]] : vector<1x[4]x1xi1> to vector<1x[4]xi1>
+// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x[4]x1xf32> to vector<1x[4]xf32>
+// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_1]], %[[VAL_4]] : vector<1x[4]xf32>
+// CHECK: %[[VAL_6:.*]] = arith.select %[[VAL_3]], %[[VAL_5]], %[[VAL_4]] : vector<1x[4]xi1>, vector<1x[4]xf32>
+  %0 = vector.mask %mask { vector.multi_reduction <add>, %source, %acc [2] : vector<1x[4]x1xf32> to vector<1x[4]xf32> } :
+        vector<1x[4]x1xi1> -> vector<1x[4]xf32>
+
+  return %0 : vector<1x[4]xf32>
+}
+
 // -----
 
 // CHECK-LABEL: func @masked_vector_multi_reduction_unit_dimensions