Index: mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
===================================================================
--- mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -182,6 +182,13 @@
 /// memory.
 bool isDisjointTransferSet(VectorTransferOpInterface transferA,
                            VectorTransferOpInterface transferB);
+
+/// Return the result value of reducing two scalar/vector values `v1` and `v2`
+/// with the arith operation corresponding to `kind`. ADD and MUL accept either
+/// integer/index or float element types; every other kind expects the single
+/// element-type family its arith op supports (e.g. MAXF/MINF expect floats).
+Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
+                         Value v1, Value v2);
 } // namespace vector
 } // namespace mlir

Index: mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
===================================================================
--- mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -34,11 +34,6 @@
 /// Helper function that creates a memref::DimOp or tensor::DimOp depending on
 /// the type of `source`.
 Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
-
-/// Return the result value of reducing two scalar/vector values with the
-/// corresponding arith operation.
-Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
-                         Value v1, Value v2);
 } // namespace vector

 /// Return the number of elements of basis, `0` if empty.
Index: mlir/lib/Dialect/Vector/IR/VectorOps.cpp
===================================================================
--- mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -501,19 +501,9 @@
                                               reductionOp.getVector(),
                                               rewriter.getI64ArrayAttr(0));

-    if (Value acc = reductionOp.getAcc()) {
-      assert(reductionOp.getType().isa<FloatType>());
-      switch (reductionOp.getKind()) {
-      case CombiningKind::ADD:
-        result = rewriter.create<arith::AddFOp>(loc, result, acc);
-        break;
-      case CombiningKind::MUL:
-        result = rewriter.create<arith::MulFOp>(loc, result, acc);
-        break;
-      default:
-        assert(false && "invalid op!");
-      }
-    }
+    if (Value acc = reductionOp.getAcc())
+      result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
+                                          result, acc);

     rewriter.replaceOp(reductionOp, result);
     return success();
@@ -5007,6 +4997,58 @@
       verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
 }

+// ADD and MUL select the integer or float arith op from the element type; the
+// remaining kinds support a single element-type family, checked by asserts.
+Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
+                                       CombiningKind kind, Value v1, Value v2) {
+  Type t1 = getElementTypeOrSelf(v1.getType());
+  Type t2 = getElementTypeOrSelf(v2.getType());
+  switch (kind) {
+  case CombiningKind::ADD:
+    if (t1.isIntOrIndex() && t2.isIntOrIndex())
+      return b.createOrFold<arith::AddIOp>(loc, v1, v2);
+    if (t1.isa<FloatType>() && t2.isa<FloatType>())
+      return b.createOrFold<arith::AddFOp>(loc, v1, v2);
+    llvm_unreachable("invalid value types for ADD reduction");
+  case CombiningKind::AND:
+    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+    return b.createOrFold<arith::AndIOp>(loc, v1, v2);
+  case CombiningKind::MAXF:
+    assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
+           "expected float values");
+    return b.createOrFold<arith::MaxFOp>(loc, v1, v2);
+  case CombiningKind::MINF:
+    assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
+           "expected float values");
+    return b.createOrFold<arith::MinFOp>(loc, v1, v2);
+  case CombiningKind::MAXSI:
+    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+    return b.createOrFold<arith::MaxSIOp>(loc, v1, v2);
+  case CombiningKind::MINSI:
+    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+    return b.createOrFold<arith::MinSIOp>(loc, v1, v2);
+  case CombiningKind::MAXUI:
+    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+    return b.createOrFold<arith::MaxUIOp>(loc, v1, v2);
+  case CombiningKind::MINUI:
+    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+    return b.createOrFold<arith::MinUIOp>(loc, v1, v2);
+  case CombiningKind::MUL:
+    if (t1.isIntOrIndex() && t2.isIntOrIndex())
+      return b.createOrFold<arith::MulIOp>(loc, v1, v2);
+    if (t1.isa<FloatType>() && t2.isa<FloatType>())
+      return b.createOrFold<arith::MulFOp>(loc, v1, v2);
+    llvm_unreachable("invalid value types for MUL reduction");
+  case CombiningKind::OR:
+    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+    return b.createOrFold<arith::OrIOp>(loc, v1, v2);
+  case CombiningKind::XOR:
+    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+    return b.createOrFold<arith::XOrIOp>(loc, v1, v2);
+  }
+  llvm_unreachable("unknown CombiningKind");
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
Index: mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
===================================================================
--- mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -43,56 +43,6 @@
   llvm_unreachable("Expected MemRefType or TensorType");
 }

-Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
-                                       CombiningKind kind, Value v1, Value v2) {
-  Type t1 = getElementTypeOrSelf(v1.getType());
-  Type t2 = getElementTypeOrSelf(v2.getType());
-  switch (kind) {
-  case CombiningKind::ADD:
-    if (t1.isIntOrIndex() && t2.isIntOrIndex())
-      return b.createOrFold<arith::AddIOp>(loc, v1, v2);
-    else if (t1.isa<FloatType>() && t2.isa<FloatType>())
-      return b.createOrFold<arith::AddFOp>(loc, v1, v2);
-    llvm_unreachable("invalid value types for ADD reduction");
-  case CombiningKind::AND:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold<arith::AndIOp>(loc, v1, v2);
-  case CombiningKind::MAXF:
-    assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
-           "expected float values");
-    return b.createOrFold<arith::MaxFOp>(loc, v1, v2);
-  case CombiningKind::MINF:
-    assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
-           "expected float values");
-    return b.createOrFold<arith::MinFOp>(loc, v1, v2);
-  case CombiningKind::MAXSI:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold<arith::MaxSIOp>(loc, v1, v2);
-  case CombiningKind::MINSI:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold<arith::MinSIOp>(loc, v1, v2);
-  case CombiningKind::MAXUI:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold<arith::MaxUIOp>(loc, v1, v2);
-  case CombiningKind::MINUI:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold<arith::MinUIOp>(loc, v1, v2);
-  case CombiningKind::MUL:
-    if (t1.isIntOrIndex() && t2.isIntOrIndex())
-      return b.createOrFold<arith::MulIOp>(loc, v1, v2);
-    else if (t1.isa<FloatType>() && t2.isa<FloatType>())
-      return b.createOrFold<arith::MulFOp>(loc, v1, v2);
-    llvm_unreachable("invalid value types for MUL reduction");
-  case CombiningKind::OR:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold<arith::OrIOp>(loc, v1, v2);
-  case CombiningKind::XOR:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold<arith::XOrIOp>(loc, v1, v2);
-  };
-  llvm_unreachable("unknown CombiningKind");
-}
-
 /// Return the number of elements of basis, `0` if empty.
 int64_t mlir::computeMaxLinearIndex(ArrayRef<int64_t> basis) {
   if (basis.empty())
Index: mlir/test/Dialect/Vector/canonicalize.mlir
===================================================================
--- mlir/test/Dialect/Vector/canonicalize.mlir
+++ mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1619,6 +1619,18 @@

 // -----

+// CHECK-LABEL: func @reduce_one_element_vector_maxf
+// CHECK-SAME: (%[[V:.+]]: vector<1xf32>, %[[B:.+]]: f32)
+// CHECK: %[[A:.+]] = vector.extract %[[V]][0] : vector<1xf32>
+// CHECK: %[[S:.+]] = arith.maxf %[[A]], %[[B]] : f32
+// CHECK: return %[[S]]
+func.func @reduce_one_element_vector_maxf(%a : vector<1xf32>, %b: f32) -> f32 {
+  %s = vector.reduction <maxf>, %a, %b : vector<1xf32> into f32
+  return %s : f32
+}
+
+// -----
+
 // CHECK-LABEL: func @bitcast(
 // CHECK-SAME: %[[ARG:.*]]: vector<4x8xf32>) -> vector<4x16xi16> {
 // CHECK: vector.bitcast %[[ARG:.*]] : vector<4x8xf32> to vector<4x16xi16>