diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -561,7 +561,7 @@ auto kind = reductionOp.kind(); Type eltType = reductionOp.dest().getType(); Type llvmType = typeConverter.convertType(eltType); - if (eltType.isSignlessInteger(32) || eltType.isSignlessInteger(64)) { + if (eltType.isSignlessInteger()) { // Integer reductions: add/mul/min/max/and/or/xor. if (kind == "add") rewriter.replaceOpWithNewOp( @@ -588,7 +588,7 @@ return failure(); return success(); - } else if (eltType.isF32() || eltType.isF64()) { + } else if (eltType.isa()) { // Floating-point reductions: add/mul/min/max if (kind == "add") { // Optional accumulator (or zero). diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -132,11 +132,10 @@ auto kind = op.kind(); Type eltType = op.dest().getType(); if (kind == "add" || kind == "mul" || kind == "min" || kind == "max") { - if (!eltType.isF32() && !eltType.isF64() && - !eltType.isSignlessInteger(32) && !eltType.isSignlessInteger(64)) + if (!eltType.isSignlessIntOrFloat()) return op.emitOpError("unsupported reduction type"); } else if (kind == "and" || kind == "or" || kind == "xor") { - if (!eltType.isSignlessInteger(32) && !eltType.isSignlessInteger(64)) + if (!eltType.isSignlessInteger()) return op.emitOpError("unsupported reduction type"); } else { return op.emitOpError("unknown reduction kind: ") << kind; @@ -146,7 +145,7 @@ if (!op.acc().empty()) { if (kind != "add" && kind != "mul") return op.emitOpError("no accumulator for reduction kind: ") << kind; - if (!eltType.isF32() && !eltType.isF64()) + if (!eltType.isa()) return op.emitOpError("no accumulator for type: ") << eltType; } diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -678,6 +678,17 @@ return %0, %1: vector<8xf32>, vector<2x4xf32> } +func @reduce_f16(%arg0: vector<16xf16>) -> f16 { + %0 = vector.reduction "add", %arg0 : vector<16xf16> into f16 + return %0 : f16 +} +// CHECK-LABEL: llvm.func @reduce_f16( +// CHECK-SAME: %[[A:.*]]: !llvm.vec<16 x half>) +// CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f16) : !llvm.half +// CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fadd"(%[[C]], %[[A]]) +// CHECK-SAME: {reassoc = false} : (!llvm.half, !llvm.vec<16 x half>) -> !llvm.half +// CHECK: llvm.return %[[V]] : !llvm.half + func @reduce_f32(%arg0: vector<16xf32>) -> f32 { %0 = vector.reduction "add", %arg0 : vector<16xf32> into f32 return %0 : f32 @@ -700,6 +711,15 @@ // CHECK-SAME: {reassoc = false} : (!llvm.double, !llvm.vec<16 x double>) -> !llvm.double // CHECK: llvm.return %[[V]] : !llvm.double +func @reduce_i8(%arg0: vector<16xi8>) -> i8 { + %0 = vector.reduction "add", %arg0 : vector<16xi8> into i8 + return %0 : i8 +} +// CHECK-LABEL: llvm.func @reduce_i8( +// CHECK-SAME: %[[A:.*]]: !llvm.vec<16 x i8>) +// CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.add"(%[[A]]) +// CHECK: llvm.return %[[V]] : !llvm.i8 + func @reduce_i32(%arg0: vector<16xi32>) -> i32 { %0 = vector.reduction "add", %arg0 : vector<16xi32> into i32 return %0 : i32