diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-reductions-i4.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-reductions-i4.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-reductions-i4.mlir
@@ -0,0 +1,44 @@
+// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+func @entry() {
+  %v = std.constant dense<[-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<24xi4>
+  vector.print %v : vector<24xi4>
+  //
+  // Test vector:
+  //
+  // CHECK: ( -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1 )
+
+
+  %0 = vector.reduction "add", %v : vector<24xi4> into i4
+  vector.print %0 : i4
+  // CHECK: 4
+
+  %1 = vector.reduction "mul", %v : vector<24xi4> into i4
+  vector.print %1 : i4
+  // CHECK: 0
+
+  %2 = vector.reduction "min", %v : vector<24xi4> into i4
+  vector.print %2 : i4
+  // CHECK: -8
+
+  %3 = vector.reduction "max", %v : vector<24xi4> into i4
+  vector.print %3 : i4
+  // CHECK: 7
+
+  %4 = vector.reduction "and", %v : vector<24xi4> into i4
+  vector.print %4 : i4
+  // CHECK: 0
+
+  %5 = vector.reduction "or", %v : vector<24xi4> into i4
+  vector.print %5 : i4
+  // CHECK: -1
+
+  %6 = vector.reduction "xor", %v : vector<24xi4> into i4
+  vector.print %6 : i4
+  // CHECK: 0
+
+  return
+}
diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-reductions-si4.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-reductions-si4.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-reductions-si4.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+func @entry() {
+  %v = std.constant dense<[-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7]> : vector<16xsi4>
+  vector.print %v : vector<16xsi4>
+  //
+  // Test vector:
+  //
+  // CHECK: ( -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7 )
+
+  %0 = vector.reduction "add", %v : vector<16xsi4> into si4
+  vector.print %0 : si4
+  // CHECK: -8
+
+  %1 = vector.reduction "mul", %v : vector<16xsi4> into si4
+  vector.print %1 : si4
+  // CHECK: 0
+
+  %2 = vector.reduction "min", %v : vector<16xsi4> into si4
+  vector.print %2 : si4
+  // CHECK: -8
+
+  %3 = vector.reduction "max", %v : vector<16xsi4> into si4
+  vector.print %3 : si4
+  // CHECK: 7
+
+  %4 = vector.reduction "and", %v : vector<16xsi4> into si4
+  vector.print %4 : si4
+  // CHECK: 0
+
+  %5 = vector.reduction "or", %v : vector<16xsi4> into si4
+  vector.print %5 : si4
+  // CHECK: -1
+
+  %6 = vector.reduction "xor", %v : vector<16xsi4> into si4
+  vector.print %6 : si4
+  // CHECK: 0
+
+  return
+}
diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-reductions-ui4.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-reductions-ui4.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-reductions-ui4.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+func @entry() {
+  %v = std.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xui4>
+  vector.print %v : vector<16xui4>
+  //
+  // Test vector:
+  //
+  // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
+
+  %0 = vector.reduction "add", %v : vector<16xui4> into ui4
+  vector.print %0 : ui4
+  // CHECK: 8
+
+  %1 = vector.reduction "mul", %v : vector<16xui4> into ui4
+  vector.print %1 : ui4
+  // CHECK: 0
+
+  %2 = vector.reduction "min", %v : vector<16xui4> into ui4
+  vector.print %2 : ui4
+  // CHECK: 0
+
+  %3 = vector.reduction "max", %v : vector<16xui4> into ui4
+  vector.print %3 : ui4
+  // CHECK: 15
+
+  %4 = vector.reduction "and", %v : vector<16xui4> into ui4
+  vector.print %4 : ui4
+  // CHECK: 0
+
+  %5 = vector.reduction "or", %v : vector<16xui4> into ui4
+  vector.print %5 : ui4
+  // CHECK: 15
+
+  %6 = vector.reduction "xor", %v : vector<16xui4> into ui4
+  vector.print %6 : ui4
+  // CHECK: 0
+
+  return
+}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -561,7 +561,7 @@
     auto kind = reductionOp.kind();
     Type eltType = reductionOp.dest().getType();
     Type llvmType = typeConverter.convertType(eltType);
-    if (eltType.isSignlessInteger()) {
+    if (eltType.isIntOrIndex()) {
       // Integer reductions: add/mul/min/max/and/or/xor.
       if (kind == "add")
         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_add>(
@@ -569,9 +569,19 @@
       else if (kind == "mul")
         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_mul>(
             op, llvmType, operands[0]);
+      // Index and unsigned integer types use unsigned min/max; signless and
+      // signed integer types fall through to the signed variants below.
+      else if (kind == "min" &&
+               (eltType.isIndex() || eltType.isUnsignedInteger()))
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_umin>(
+            op, llvmType, operands[0]);
       else if (kind == "min")
         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smin>(
             op, llvmType, operands[0]);
+      else if (kind == "max" &&
+               (eltType.isIndex() || eltType.isUnsignedInteger()))
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_umax>(
+            op, llvmType, operands[0]);
       else if (kind == "max")
         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smax>(
             op, llvmType, operands[0]);
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -132,10 +132,10 @@
   auto kind = op.kind();
   Type eltType = op.dest().getType();
   if (kind == "add" || kind == "mul" || kind == "min" || kind == "max") {
-    if (!eltType.isSignlessIntOrFloat())
+    if (!eltType.isIntOrIndexOrFloat())
       return op.emitOpError("unsupported reduction type");
   } else if (kind == "and" || kind == "or" || kind == "xor") {
-    if (!eltType.isSignlessInteger())
+    if (!eltType.isIntOrIndex())
       return op.emitOpError("unsupported reduction type");
   } else {
     return op.emitOpError("unknown reduction kind: ") << kind;