diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -91,7 +91,7 @@ Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc, Variadic>:$masks, Vector_AffineMapArrayAttr:$indexing_maps, - ArrayAttr:$iterator_types, + ArrayAttr:$iterator_types, DefaultValuedAttr:$kind)>, Results<(outs AnyType)> { @@ -280,8 +280,7 @@ let description = [{ Reduces an 1-D vector "horizontally" into a scalar using the given operation (add/mul/min/max for int/fp and and/or/xor for int only). - Some reductions (add/mul for fp) also allow an optional fused - accumulator. + Reductions also allow an optional fused accumulator. Note that these operations are restricted to 1-D vectors to remain close to the corresponding LLVM intrinsics: @@ -1760,7 +1759,7 @@ Vector_Op<"gather">, Arguments<(ins Arg:$base, Variadic:$indices, - VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec, + VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec, VectorOfRankAndType<[1], [I1]>:$mask, VectorOfRank<[1]>:$pass_thru)>, Results<(outs VectorOfRank<[1]>:$result)> { @@ -1826,7 +1825,7 @@ Vector_Op<"scatter">, Arguments<(ins Arg:$base, Variadic:$indices, - VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec, + VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec, VectorOfRankAndType<[1], [I1]>:$mask, VectorOfRank<[1]>:$valueToStore)> { diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -362,6 +362,37 @@ } }; +/// Helper method to lower a `vector.reduction` op that performs an arithmetic +/// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use +/// and `ScalarOp` is the scalar operation used to add the accumulation value if +/// non-null. +template +static Value createIntegerReductionArithmeticOpLowering( + ConversionPatternRewriter &rewriter, Location loc, Type llvmType, + Value vectorOperand, Value accumulator) { + Value result = rewriter.create(loc, llvmType, vectorOperand); + if (accumulator) + result = rewriter.create(loc, accumulator, result); + return result; +} + +/// Helper method to lower a `vector.reduction` operation that performs +/// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector +/// intrinsic to use and `predicate` is the predicate to use to compare+combine +/// the accumulator value if non-null. +template +static Value createIntegerReductionComparisonOpLowering( + ConversionPatternRewriter &rewriter, Location loc, Type llvmType, + Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) { + Value result = rewriter.create(loc, llvmType, vectorOperand); + if (accumulator) { + Value cmp = + rewriter.create(loc, predicate, accumulator, result); + result = rewriter.create(loc, cmp, accumulator, result); + } + return result; +} + /// Conversion pattern for all vector reductions. class VectorReductionOpConversion : public ConvertOpToLLVMPattern { @@ -377,38 +408,68 @@ auto kind = reductionOp.getKind(); Type eltType = reductionOp.getDest().getType(); Type llvmType = typeConverter->convertType(eltType); - Value operand = adaptor.getOperands()[0]; + Value operand = adaptor.getVector(); + Value acc = adaptor.getAcc(); + Location loc = reductionOp.getLoc(); if (eltType.isIntOrIndex()) { // Integer reductions: add/mul/min/max/and/or/xor. - if (kind == vector::CombiningKind::ADD) - rewriter.replaceOpWithNewOp(reductionOp, - llvmType, operand); - else if (kind == vector::CombiningKind::MUL) - rewriter.replaceOpWithNewOp(reductionOp, - llvmType, operand); - else if (kind == vector::CombiningKind::MINUI) - rewriter.replaceOpWithNewOp( - reductionOp, llvmType, operand); - else if (kind == vector::CombiningKind::MINSI) - rewriter.replaceOpWithNewOp( - reductionOp, llvmType, operand); - else if (kind == vector::CombiningKind::MAXUI) - rewriter.replaceOpWithNewOp( - reductionOp, llvmType, operand); - else if (kind == vector::CombiningKind::MAXSI) - rewriter.replaceOpWithNewOp( - reductionOp, llvmType, operand); - else if (kind == vector::CombiningKind::AND) - rewriter.replaceOpWithNewOp(reductionOp, - llvmType, operand); - else if (kind == vector::CombiningKind::OR) - rewriter.replaceOpWithNewOp(reductionOp, - llvmType, operand); - else if (kind == vector::CombiningKind::XOR) - rewriter.replaceOpWithNewOp(reductionOp, - llvmType, operand); - else + Value result; + switch (kind) { + case vector::CombiningKind::ADD: + result = + createIntegerReductionArithmeticOpLowering( + rewriter, loc, llvmType, operand, acc); + break; + case vector::CombiningKind::MUL: + result = + createIntegerReductionArithmeticOpLowering( + rewriter, loc, llvmType, operand, acc); + break; + case vector::CombiningKind::MINUI: + result = createIntegerReductionComparisonOpLowering< + LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc, + LLVM::ICmpPredicate::ule); + break; + case vector::CombiningKind::MINSI: + result = createIntegerReductionComparisonOpLowering< + LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc, + LLVM::ICmpPredicate::sle); + break; + case vector::CombiningKind::MAXUI: + result = createIntegerReductionComparisonOpLowering< + LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc, + LLVM::ICmpPredicate::uge); + break; + case vector::CombiningKind::MAXSI: + result = createIntegerReductionComparisonOpLowering< + LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc, + LLVM::ICmpPredicate::sge); + break; + case vector::CombiningKind::AND: + result = + createIntegerReductionArithmeticOpLowering( + rewriter, loc, llvmType, operand, acc); + break; + case vector::CombiningKind::OR: + result = + createIntegerReductionArithmeticOpLowering( + rewriter, loc, llvmType, operand, acc); + break; + case vector::CombiningKind::XOR: + result = + createIntegerReductionArithmeticOpLowering( + rewriter, loc, llvmType, operand, acc); + break; + default: return failure(); + } + rewriter.replaceOp(reductionOp, result); + return success(); } diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -403,15 +403,6 @@ << eltType << "' for kind '" << stringifyCombiningKind(getKind()) << "'"; - // Verify optional accumulator. - if (getAcc()) { - if (getKind() != CombiningKind::ADD && getKind() != CombiningKind::MUL) - return emitOpError("no accumulator for reduction kind: ") - << stringifyCombiningKind(getKind()); - if (!eltType.isa()) - return emitOpError("no accumulator for type: ") << eltType; - } - return success(); } @@ -1969,7 +1960,7 @@ (static_cast(srcVectorType.getRank()) + positionAttr.size() != static_cast(destVectorType.getRank()))) return emitOpError("expected position attribute rank + source rank to " - "match dest vector rank"); + "match dest vector rank"); if (!srcVectorType && (positionAttr.size() != static_cast(destVectorType.getRank()))) return emitOpError( @@ -2302,8 +2293,7 @@ int64_t numFixedVectorSizes = fixedVectorSizes.size(); if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes) - return emitError("invalid input shape for vector type ") - << inputVectorType; + return emitError("invalid input shape for vector type ") << inputVectorType; if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes) return emitError("invalid output shape for vector type ") @@ -2396,24 +2386,29 @@ auto sizes = getSizesAttr(); auto strides = getStridesAttr(); if (offsets.size() != sizes.size() || offsets.size() != strides.size()) - return emitOpError("expected offsets, sizes and strides attributes of same size"); + return emitOpError( + "expected offsets, sizes and strides attributes of same size"); auto shape = type.getShape(); auto offName = getOffsetsAttrName(); auto sizesName = getSizesAttrName(); auto stridesName = getStridesAttrName(); - if (failed(isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) || - failed(isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) || + if (failed( + isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) || + failed( + isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) || failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape, stridesName)) || - failed(isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) || + failed( + isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) || failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName, /*halfOpen=*/false, /*min=*/1)) || - failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1, stridesName, + failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1, + stridesName, /*halfOpen=*/false)) || - failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes, shape, - offName, sizesName, + failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes, + shape, offName, sizesName, /*halfOpen=*/false))) return failure(); @@ -4223,7 +4218,7 @@ if (sourceVectorType.getRank() == 0) { if (sourceElementBits != resultElementBits) return emitOpError("source/result bitwidth of the 0-D vector element " - "types must be equal"); + "types must be equal"); } else if (sourceElementBits * sourceVectorType.getShape().back() != resultElementBits * resultVectorType.getShape().back()) { return emitOpError( diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1875,10 +1875,9 @@ assert(rhsType.getRank() == 1 && "corrupt contraction"); Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter); auto kind = vector::CombiningKind::ADD; - Value res = rewriter.create(loc, kind, m); if (auto acc = op.getAcc()) - res = createAdd(op.getLoc(), res, acc, isInt, rewriter); - return res; + return rewriter.create(loc, kind, m, acc); + return rewriter.create(loc, kind, m); } // Construct new iterator types and affine map array attribute. std::array lowIndexingMaps = { diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1178,6 +1178,206 @@ // ----- +func.func @reduce_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 { + %0 = vector.reduction , %arg0, %arg1 : vector<16xi32> into i32 + return %0 : i32 +} +// CHECK-LABEL: @reduce_acc_i32( +// CHECK-SAME: %[[A:.*]]: vector<16xi32>, %[[ACC:.*]]: i32) +// CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.add"(%[[A]]) +// CHECK: %[[V:.*]] = llvm.add %[[ACC]], %[[R]] +// CHECK: return %[[V]] : i32 + +// ----- + +func.func @reduce_mul_i32(%arg0: vector<16xi32>) -> i32 { + %0 = vector.reduction , %arg0 : vector<16xi32> into i32 + return %0 : i32 +} +// CHECK-LABEL: @reduce_mul_i32( +// CHECK-SAME: %[[A:.*]]: vector<16xi32>) +// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.mul"(%[[A]]) +// CHECK: return %[[V]] : i32 + +// ----- + +func.func @reduce_mul_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 { + %0 = vector.reduction , %arg0, %arg1 : vector<16xi32> into i32 + return %0 : i32 +} +// CHECK-LABEL: @reduce_mul_acc_i32( +// CHECK-SAME: %[[A:.*]]: vector<16xi32>, %[[ACC:.*]]: i32) +// CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.mul"(%[[A]]) +// CHECK: %[[V:.*]] = llvm.mul %[[ACC]], %[[R]] +// CHECK: return %[[V]] : i32 + +// ----- + +func.func @reduce_minui_i32(%arg0: vector<16xi32>) -> i32 { + %0 = vector.reduction , %arg0 : vector<16xi32> into i32 + return %0 : i32 +} +// CHECK-LABEL: @reduce_minui_i32( +// CHECK-SAME: %[[A:.*]]: vector<16xi32>) +// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.umin"(%[[A]]) +// CHECK: return %[[V]] : i32 + +// ----- + +func.func @reduce_minui_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 { + %0 = vector.reduction , %arg0, %arg1 : vector<16xi32> into i32 + return %0 : i32 +} +// CHECK-LABEL: @reduce_minui_acc_i32( +// CHECK-SAME: %[[A:.*]]: vector<16xi32>, %[[ACC:.*]]: i32) +// CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.umin"(%[[A]]) +// CHECK: %[[S:.*]] = llvm.icmp "ule" %[[ACC]], %[[R]] +// CHECK: %[[V:.*]] = llvm.select %[[S]], %[[ACC]], %[[R]] +// CHECK: return %[[V]] : i32 + +// ----- + +func.func @reduce_maxui_i32(%arg0: vector<16xi32>) -> i32 { + %0 = vector.reduction , %arg0 : vector<16xi32> into i32 + return %0 : i32 +} +// CHECK-LABEL: @reduce_maxui_i32( +// CHECK-SAME: %[[A:.*]]: vector<16xi32>) +// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.umax"(%[[A]]) +// CHECK: return %[[V]] : i32 + +// ----- + +func.func @reduce_maxui_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 { + %0 = vector.reduction , %arg0, %arg1 : vector<16xi32> into i32 + return %0 : i32 +} +// CHECK-LABEL: @reduce_maxui_acc_i32( +// CHECK-SAME: %[[A:.*]]: vector<16xi32>, %[[ACC:.*]]: i32) +// CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.umax"(%[[A]]) +// CHECK: %[[S:.*]] = llvm.icmp "uge" %[[ACC]], %[[R]] +// CHECK: %[[V:.*]] = llvm.select %[[S]], %[[ACC]], %[[R]] +// CHECK: return %[[V]] : i32 + +// ----- + +func.func @reduce_minsi_i32(%arg0: vector<16xi32>) -> i32 { + %0 = vector.reduction , %arg0 : vector<16xi32> into i32 + return %0 : i32 +} +// CHECK-LABEL: @reduce_minsi_i32( +// CHECK-SAME: %[[A:.*]]: vector<16xi32>) +// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.smin"(%[[A]]) +// CHECK: return %[[V]] : i32 + +// ----- + +func.func @reduce_minsi_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 { + %0 = vector.reduction , %arg0, %arg1 : vector<16xi32> into i32 + return %0 : i32 +} +// CHECK-LABEL: @reduce_minsi_acc_i32( +// CHECK-SAME: %[[A:.*]]: vector<16xi32>, %[[ACC:.*]]: i32) +// CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.smin"(%[[A]]) +// CHECK: %[[S:.*]] = llvm.icmp "sle" %[[ACC]], %[[R]] +// CHECK: %[[V:.*]] = llvm.select %[[S]], %[[ACC]], %[[R]] +// CHECK: return %[[V]] : i32 + +// ----- + +func.func @reduce_maxsi_i32(%arg0: vector<16xi32>) -> i32 { + %0 = vector.reduction , %arg0 : vector<16xi32> into i32 + return %0 : i32 +} +// CHECK-LABEL: @reduce_maxsi_i32( +// CHECK-SAME: %[[A:.*]]: vector<16xi32>) +// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.smax"(%[[A]]) +// CHECK: return %[[V]] : i32 + +// ----- + +func.func @reduce_maxsi_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 { + %0 = vector.reduction , %arg0, %arg1 : vector<16xi32> into i32 + return %0 : i32 +} +// CHECK-LABEL: @reduce_maxsi_acc_i32( +// CHECK-SAME: %[[A:.*]]: vector<16xi32>, %[[ACC:.*]]: i32) +// CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.smax"(%[[A]]) +// CHECK: %[[S:.*]] = llvm.icmp "sge" %[[ACC]], %[[R]] +// CHECK: %[[V:.*]] = llvm.select %[[S]], %[[ACC]], %[[R]] +// CHECK: return %[[V]] : i32 + +// ----- + +func.func @reduce_and_i32(%arg0: vector<16xi32>) -> i32 { + %0 = vector.reduction , %arg0 : vector<16xi32> into i32 + return %0 : i32 +} +// CHECK-LABEL: @reduce_and_i32( +// CHECK-SAME: %[[A:.*]]: vector<16xi32>) +// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.and"(%[[A]]) +// CHECK: return %[[V]] : i32 + +// ----- + +func.func @reduce_and_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 { + %0 = vector.reduction , %arg0, %arg1 : vector<16xi32> into i32 + return %0 : i32 +} +// CHECK-LABEL: @reduce_and_acc_i32( +// CHECK-SAME: %[[A:.*]]: vector<16xi32>, %[[ACC:.*]]: i32) +// CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.and"(%[[A]]) +// CHECK: %[[V:.*]] = llvm.and %[[ACC]], %[[R]] +// CHECK: return %[[V]] : i32 + +// ----- + +func.func @reduce_or_i32(%arg0: vector<16xi32>) -> i32 { + %0 = vector.reduction , %arg0 : vector<16xi32> into i32 + return %0 : i32 +} +// CHECK-LABEL: @reduce_or_i32( +// CHECK-SAME: %[[A:.*]]: vector<16xi32>) +// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.or"(%[[A]]) +// CHECK: return %[[V]] : i32 + +// ----- + +func.func @reduce_or_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 { + %0 = vector.reduction , %arg0, %arg1 : vector<16xi32> into i32 + return %0 : i32 +} +// CHECK-LABEL: @reduce_or_acc_i32( +// CHECK-SAME: %[[A:.*]]: vector<16xi32>, %[[ACC:.*]]: i32) +// CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.or"(%[[A]]) +// CHECK: %[[V:.*]] = llvm.or %[[ACC]], %[[R]] +// CHECK: return %[[V]] : i32 + +// ----- + +func.func @reduce_xor_i32(%arg0: vector<16xi32>) -> i32 { + %0 = vector.reduction , %arg0 : vector<16xi32> into i32 + return %0 : i32 +} +// CHECK-LABEL: @reduce_xor_i32( +// CHECK-SAME: %[[A:.*]]: vector<16xi32>) +// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.xor"(%[[A]]) +// CHECK: return %[[V]] : i32 + +// ----- + +func.func @reduce_xor_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 { + %0 = vector.reduction , %arg0, %arg1 : vector<16xi32> into i32 + return %0 : i32 +} +// CHECK-LABEL: @reduce_xor_acc_i32( +// CHECK-SAME: %[[A:.*]]: vector<16xi32>, %[[ACC:.*]]: i32) +// CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.xor"(%[[A]]) +// CHECK: %[[V:.*]] = llvm.xor %[[ACC]], %[[R]] +// CHECK: return %[[V]] : i32 + +// ----- + func.func @reduce_i64(%arg0: vector<16xi64>) -> i64 { %0 = vector.reduction , %arg0 : vector<16xi64> into i64 return %0 : i64 diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1116,27 +1116,6 @@ // ----- -func.func @reduce_unsupported_accumulator_kind(%arg0: vector<16xf32>, %arg1: f32) -> f32 { - // expected-error@+1 {{'vector.reduction' op no accumulator for reduction kind: min}} - %0 = vector.reduction , %arg0, %arg1 : vector<16xf32> into f32 -} - -// ----- - -func.func @reduce_unsupported_accumulator_type(%arg0: vector<16xi32>, %arg1: i32) -> i32 { - // expected-error@+1 {{'vector.reduction' op no accumulator for type: 'i32'}} - %0 = vector.reduction , %arg0, %arg1 : vector<16xi32> into i32 -} - -// ----- - -func.func @reduce_unsupported_type(%arg0: vector<16xf32>) -> f32 { - // expected-error@+1 {{'vector.reduction' op unsupported reduction type}} - %0 = vector.reduction , %arg0 : vector<16xf32> into f32 -} - -// ----- - func.func @reduce_unsupported_rank(%arg0: vector<4x16xf32>) -> f32 { // expected-error@+1 {{'vector.reduction' op unsupported reduction rank: 2}} %0 = vector.reduction , %arg0 : vector<4x16xf32> into f32 diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir @@ -19,9 +19,8 @@ // CHECK-SAME: %[[B:.*1]]: vector<4xf32>, // CHECK-SAME: %[[C:.*2]]: f32 // CHECK: %[[F:.*]] = arith.mulf %[[A]], %[[B]] : vector<4xf32> -// CHECK: %[[R:.*]] = vector.reduction , %[[F]] : vector<4xf32> into f32 -// CHECK: %[[ACC:.*]] = arith.addf %[[R]], %[[C]] : f32 -// CHECK: return %[[ACC]] : f32 +// CHECK: %[[R:.*]] = vector.reduction , %[[F]], %[[C]] : vector<4xf32> into f32 +// CHECK: return %[[R]] : f32 func.func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32) -> f32 { %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2 @@ -34,9 +33,8 @@ // CHECK-SAME: %[[B:.*1]]: vector<4xi32>, // CHECK-SAME: %[[C:.*2]]: i32 // CHECK: %[[F:.*]] = arith.muli %[[A]], %[[B]] : vector<4xi32> -// CHECK: %[[R:.*]] = vector.reduction , %[[F]] : vector<4xi32> into i32 -// CHECK: %[[ACC:.*]] = arith.addi %[[R]], %[[C]] : i32 -// CHECK: return %[[ACC]] : i32 +// CHECK: %[[R:.*]] = vector.reduction , %[[F]], %[[C]] : vector<4xi32> into i32 +// CHECK: return %[[R]] : i32 func.func @extract_contract1_int(%arg0: vector<4xi32>, %arg1: vector<4xi32>, %arg2: i32) -> i32 { %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2 @@ -72,7 +70,7 @@ func.func @extract_contract2(%arg0: vector<2x3xf32>, %arg1: vector<3xf32>, - %arg2: vector<2xf32>) -> vector<2xf32> { + %arg2: vector<2xf32>) -> vector<2xf32> { %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2 : vector<2x3xf32>, vector<3xf32> into vector<2xf32> return %0 : vector<2xf32> @@ -95,7 +93,7 @@ // CHECK: return %[[T10]] : vector<2xi32> func.func @extract_contract2_int(%arg0: vector<2x3xi32>, %arg1: vector<3xi32>, - %arg2: vector<2xi32>) -> vector<2xi32> { + %arg2: vector<2xi32>) -> vector<2xi32> { %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2 : vector<2x3xi32>, vector<3xi32> into vector<2xi32> return %0 : vector<2xi32> @@ -201,18 +199,16 @@ // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32> // CHECK: %[[T1:.*]] = vector.extract %[[B]][0] : vector<2x3xf32> // CHECK: %[[T2:.*]] = arith.mulf %[[T0]], %[[T1]] : vector<3xf32> -// CHECK: %[[T3:.*]] = vector.reduction , %[[T2]] : vector<3xf32> into f32 -// CHECK: %[[T4:.*]] = arith.addf %[[T3]], %[[C]] : f32 +// CHECK: %[[T3:.*]] = vector.reduction , %[[T2]], %[[C]] : vector<3xf32> into f32 // CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> // CHECK: %[[T6:.*]] = vector.extract %[[B]][1] : vector<2x3xf32> // CHECK: %[[T7:.*]] = arith.mulf %[[T5]], %[[T6]] : vector<3xf32> -// CHECK: %[[T8:.*]] = vector.reduction , %[[T7]] : vector<3xf32> into f32 -// CHECK: %[[T9:.*]] = arith.addf %[[T8]], %[[T4]] : f32 -// CHECK: return %[[T9]] : f32 +// CHECK: %[[T8:.*]] = vector.reduction , %[[T7]], %[[T3]] : vector<3xf32> into f32 +// CHECK: return %[[T8]] : f32 func.func @full_contract1(%arg0: vector<2x3xf32>, %arg1: vector<2x3xf32>, - %arg2: f32) -> f32 { + %arg2: f32) -> f32 { %0 = vector.contract #contraction2d_trait %arg0, %arg1, %arg2 : vector<2x3xf32>, vector<2x3xf32> into f32 return %0 : f32 @@ -241,8 +237,7 @@ // CHECK: %[[T7:.*]] = vector.extract %[[B]][2, 0] : vector<3x2xf32> // CHECK: %[[T9:.*]] = vector.insert %[[T7]], %[[T6]] [2] : f32 into vector<3xf32> // CHECK: %[[T10:.*]] = arith.mulf %[[T0]], %[[T9]] : vector<3xf32> -// CHECK: %[[T11:.*]] = vector.reduction , %[[T10]] : vector<3xf32> into f32 -// CHECK: %[[ACC0:.*]] = arith.addf %[[T11]], %[[C]] : f32 +// CHECK: %[[T11:.*]] = vector.reduction , %[[T10]], %[[C]] : vector<3xf32> into f32 // // CHECK: %[[T12:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> // CHECK: %[[T13:.*]] = vector.extract %[[B]][0, 1] : vector<3x2xf @@ -252,13 +247,12 @@ // CHECK: %[[T19:.*]] = vector.extract %[[B]][2, 1] : vector<3x2xf32> // CHECK: %[[T21:.*]] = vector.insert %[[T19]], %[[T18]] [2] : f32 into vector<3xf32> // CHECK: %[[T22:.*]] = arith.mulf %[[T12]], %[[T21]] : vector<3xf32> -// CHECK: %[[T23:.*]] = vector.reduction , %[[T22]] : vector<3xf32> into f32 -// CHECK: %[[ACC1:.*]] = arith.addf %[[T23]], %[[ACC0]] : f32 -// CHECK: return %[[ACC1]] : f32 +// CHECK: %[[T23:.*]] = vector.reduction , %[[T22]], %[[T11]] : vector<3xf32> into f32 +// CHECK: return %[[T23]] : f32 func.func @full_contract2(%arg0: vector<2x3xf32>, %arg1: vector<3x2xf32>, - %arg2: f32) -> f32 { + %arg2: f32) -> f32 { %0 = vector.contract #contraction2d_trans_trait %arg0, %arg1, %arg2 : vector<2x3xf32>, vector<3x2xf32> into f32 return %0 : f32