diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -268,8 +268,9 @@
                  TCresVTEtIsSameAsOpBase<0, 0>>,
      DeclareOpInterfaceMethods<VectorUnrollOpInterface,
                                ["getShapeForUnroll"]>]>,
-    Arguments<(ins Vector_CombiningKindAttr:$kind, AnyVector:$vector,
-               Optional<AnyType>:$acc)>,
+    Arguments<(ins Vector_CombiningKindAttr:$kind,
+               AnyVectorOfAnyRank:$vector,
+               Optional<AnyType>:$acc)>,
     Results<(outs AnyType:$dest)> {
   let summary = "reduction operation";
   let description = [{
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -393,9 +393,9 @@
 }
 
 LogicalResult ReductionOp::verify() {
-  // Verify for 1-D vector.
+  // Verify for 0-D and 1-D vector.
   int64_t rank = getVectorType().getRank();
-  if (rank != 1)
+  if (rank > 1)
     return emitOpError("unsupported reduction rank: ") << rank;
 
   // Verify supported reduction kind.
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1117,6 +1117,20 @@
 
 // -----
 
+func.func @reduce_0d_f32(%arg0: vector<f32>) -> f32 {
+  %0 = vector.reduction <add>, %arg0 : vector<f32> into f32
+  return %0 : f32
+}
+// CHECK-LABEL: @reduce_0d_f32(
+// CHECK-SAME: %[[A:.*]]: vector<f32>)
+// CHECK: %[[CA:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<f32> to vector<1xf32>
+// CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
+// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.fadd"(%[[C]], %[[CA]])
+// CHECK-SAME: {reassoc = false} : (f32, vector<1xf32>) -> f32
+// CHECK: return %[[V]] : f32
+
+// -----
+
 func.func @reduce_f16(%arg0: vector<16xf16>) -> f16 {
   %0 = vector.reduction <add>, %arg0 : vector<16xf16> into f16
   return %0 : f16
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -528,6 +528,30 @@
   return %0 : i32
 }
 
+// CHECK-LABEL: @reduce_int_0d
+func.func @reduce_int_0d(%arg0: vector<i32>) -> i32 {
+  // CHECK: vector.reduction <add>, %{{.*}} : vector<i32> into i32
+  vector.reduction <add>, %arg0 : vector<i32> into i32
+  // CHECK: vector.reduction <and>, %{{.*}} : vector<i32> into i32
+  vector.reduction <and>, %arg0 : vector<i32> into i32
+  // CHECK: vector.reduction <maxsi>, %{{.*}} : vector<i32> into i32
+  vector.reduction <maxsi>, %arg0 : vector<i32> into i32
+  // CHECK: vector.reduction <maxui>, %{{.*}} : vector<i32> into i32
+  vector.reduction <maxui>, %arg0 : vector<i32> into i32
+  // CHECK: vector.reduction <minsi>, %{{.*}} : vector<i32> into i32
+  vector.reduction <minsi>, %arg0 : vector<i32> into i32
+  // CHECK: vector.reduction <minui>, %{{.*}} : vector<i32> into i32
+  vector.reduction <minui>, %arg0 : vector<i32> into i32
+  // CHECK: vector.reduction <mul>, %{{.*}} : vector<i32> into i32
+  vector.reduction <mul>, %arg0 : vector<i32> into i32
+  // CHECK: vector.reduction <or>, %{{.*}} : vector<i32> into i32
+  vector.reduction <or>, %arg0 : vector<i32> into i32
+  // CHECK: %[[X:.*]] = vector.reduction <xor>, %{{.*}} : vector<i32> into i32
+  %0 = vector.reduction <xor>, %arg0 : vector<i32> into i32
+  // CHECK: return %[[X]] : i32
+  return %0 : i32
+}
+
 // CHECK-LABEL: @transpose_fp
 func.func @transpose_fp(%arg0: vector<3x7xf32>) -> vector<7x3xf32> {
   // CHECK: %[[X:.*]] = vector.transpose %{{.*}}, [1, 0] : vector<3x7xf32> to vector<7x3xf32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
@@ -105,6 +105,11 @@
   return
 }
 
+func.func @reduce_add(%arg0: vector<f32>) -> f32 {
+  %0 = vector.reduction <add>, %arg0 : vector<f32> into f32
+  return %0 : f32
+}
+
 func.func @entry() {
   %0 = arith.constant 42.0 : f32
   %1 = arith.constant dense<0.0> : vector<f32>
@@ -131,5 +136,10 @@
   %one_idx = arith.constant 1 : index
   call @create_mask_0d(%zero_idx, %one_idx) : (index, index) -> ()
 
+  %red_array = arith.constant dense<5.0> : vector<f32>
+  %red_res = call @reduce_add(%red_array) : (vector<f32>) -> (f32)
+  vector.print %red_res : f32
+  // CHECK: 5
+
   return
 }