diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -408,8 +408,6 @@ if (getKind() != CombiningKind::ADD && getKind() != CombiningKind::MUL) return emitOpError("no accumulator for reduction kind: ") << stringifyCombiningKind(getKind()); - if (!eltType.isa()) - return emitOpError("no accumulator for type: ") << eltType; } return success(); @@ -1969,7 +1967,7 @@ (static_cast(srcVectorType.getRank()) + positionAttr.size() != static_cast(destVectorType.getRank()))) return emitOpError("expected position attribute rank + source rank to " - "match dest vector rank"); + "match dest vector rank"); if (!srcVectorType && (positionAttr.size() != static_cast(destVectorType.getRank()))) return emitOpError( @@ -2302,8 +2300,7 @@ int64_t numFixedVectorSizes = fixedVectorSizes.size(); if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes) - return emitError("invalid input shape for vector type ") - << inputVectorType; + return emitError("invalid input shape for vector type ") << inputVectorType; if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes) return emitError("invalid output shape for vector type ") @@ -2396,24 +2393,29 @@ auto sizes = getSizesAttr(); auto strides = getStridesAttr(); if (offsets.size() != sizes.size() || offsets.size() != strides.size()) - return emitOpError("expected offsets, sizes and strides attributes of same size"); + return emitOpError( + "expected offsets, sizes and strides attributes of same size"); auto shape = type.getShape(); auto offName = getOffsetsAttrName(); auto sizesName = getSizesAttrName(); auto stridesName = getStridesAttrName(); - if (failed(isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) || - failed(isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) || + if (failed( + isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) || + failed( + isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) || failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape, stridesName)) || - failed(isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) || + failed( + isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) || failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName, /*halfOpen=*/false, /*min=*/1)) || - failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1, stridesName, + failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1, + stridesName, /*halfOpen=*/false)) || - failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes, shape, - offName, sizesName, + failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes, + shape, offName, sizesName, /*halfOpen=*/false))) return failure(); @@ -4223,7 +4225,7 @@ if (sourceVectorType.getRank() == 0) { if (sourceElementBits != resultElementBits) return emitOpError("source/result bitwidth of the 0-D vector element " - "types must be equal"); + "types must be equal"); } else if (sourceElementBits * sourceVectorType.getShape().back() != resultElementBits * resultVectorType.getShape().back()) { return emitOpError( diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1875,10 +1875,9 @@ assert(rhsType.getRank() == 1 && "corrupt contraction"); Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter); auto kind = vector::CombiningKind::ADD; - Value res = rewriter.create(loc, kind, m); if (auto acc = op.getAcc()) - res = createAdd(op.getLoc(), res, acc, isInt, rewriter); - return res; + return rewriter.create(loc, kind, m, acc); + return rewriter.create(loc, kind, m); } // Construct new iterator types and affine map array attribute. std::array lowIndexingMaps = { diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1123,13 +1123,6 @@ // ----- -func.func @reduce_unsupported_accumulator_type(%arg0: vector<16xi32>, %arg1: i32) -> i32 { - // expected-error@+1 {{'vector.reduction' op no accumulator for type: 'i32'}} - %0 = vector.reduction , %arg0, %arg1 : vector<16xi32> into i32 -} - -// ----- - func.func @reduce_unsupported_type(%arg0: vector<16xf32>) -> f32 { // expected-error@+1 {{'vector.reduction' op unsupported reduction type}} %0 = vector.reduction , %arg0 : vector<16xf32> into f32 diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir @@ -19,9 +19,8 @@ // CHECK-SAME: %[[B:.*1]]: vector<4xf32>, // CHECK-SAME: %[[C:.*2]]: f32 // CHECK: %[[F:.*]] = arith.mulf %[[A]], %[[B]] : vector<4xf32> -// CHECK: %[[R:.*]] = vector.reduction , %[[F]] : vector<4xf32> into f32 -// CHECK: %[[ACC:.*]] = arith.addf %[[R]], %[[C]] : f32 -// CHECK: return %[[ACC]] : f32 +// CHECK: %[[R:.*]] = vector.reduction , %[[F]], %[[C]] : vector<4xf32> into f32 +// CHECK: return %[[R]] : f32 func.func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32) -> f32 { %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2 @@ -34,9 +33,8 @@ // CHECK-SAME: %[[B:.*1]]: vector<4xi32>, // CHECK-SAME: %[[C:.*2]]: i32 // CHECK: %[[F:.*]] = arith.muli %[[A]], %[[B]] : vector<4xi32> -// CHECK: %[[R:.*]] = vector.reduction , %[[F]] : vector<4xi32> into i32 -// CHECK: %[[ACC:.*]] = arith.addi %[[R]], %[[C]] : i32 -// CHECK: return %[[ACC]] : i32 +// CHECK: %[[R:.*]] = vector.reduction , %[[F]], %[[C]] : vector<4xi32> into i32 +// CHECK: return %[[R]] : i32 func.func @extract_contract1_int(%arg0: vector<4xi32>, %arg1: vector<4xi32>, %arg2: i32) -> i32 { %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2 @@ -72,7 +70,7 @@ func.func @extract_contract2(%arg0: vector<2x3xf32>, %arg1: vector<3xf32>, - %arg2: vector<2xf32>) -> vector<2xf32> { + %arg2: vector<2xf32>) -> vector<2xf32> { %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2 : vector<2x3xf32>, vector<3xf32> into vector<2xf32> return %0 : vector<2xf32> @@ -95,7 +93,7 @@ // CHECK: return %[[T10]] : vector<2xi32> func.func @extract_contract2_int(%arg0: vector<2x3xi32>, %arg1: vector<3xi32>, - %arg2: vector<2xi32>) -> vector<2xi32> { + %arg2: vector<2xi32>) -> vector<2xi32> { %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2 : vector<2x3xi32>, vector<3xi32> into vector<2xi32> return %0 : vector<2xi32> @@ -201,18 +199,16 @@ // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32> // CHECK: %[[T1:.*]] = vector.extract %[[B]][0] : vector<2x3xf32> // CHECK: %[[T2:.*]] = arith.mulf %[[T0]], %[[T1]] : vector<3xf32> -// CHECK: %[[T3:.*]] = vector.reduction , %[[T2]] : vector<3xf32> into f32 -// CHECK: %[[T4:.*]] = arith.addf %[[T3]], %[[C]] : f32 +// CHECK: %[[T3:.*]] = vector.reduction , %[[T2]], %[[C]] : vector<3xf32> into f32 // CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> // CHECK: %[[T6:.*]] = vector.extract %[[B]][1] : vector<2x3xf32> // CHECK: %[[T7:.*]] = arith.mulf %[[T5]], %[[T6]] : vector<3xf32> -// CHECK: %[[T8:.*]] = vector.reduction , %[[T7]] : vector<3xf32> into f32 -// CHECK: %[[T9:.*]] = arith.addf %[[T8]], %[[T4]] : f32 -// CHECK: return %[[T9]] : f32 +// CHECK: %[[T8:.*]] = vector.reduction , %[[T7]], %[[T3]] : vector<3xf32> into f32 +// CHECK: return %[[T8]] : f32 func.func @full_contract1(%arg0: vector<2x3xf32>, %arg1: vector<2x3xf32>, - %arg2: f32) -> f32 { + %arg2: f32) -> f32 { %0 = vector.contract #contraction2d_trait %arg0, %arg1, %arg2 : vector<2x3xf32>, vector<2x3xf32> into f32 return %0 : f32 @@ -241,8 +237,7 @@ // CHECK: %[[T7:.*]] = vector.extract %[[B]][2, 0] : vector<3x2xf32> // CHECK: %[[T9:.*]] = vector.insert %[[T7]], %[[T6]] [2] : f32 into vector<3xf32> // CHECK: %[[T10:.*]] = arith.mulf %[[T0]], %[[T9]] : vector<3xf32> -// CHECK: %[[T11:.*]] = vector.reduction , %[[T10]] : vector<3xf32> into f32 -// CHECK: %[[ACC0:.*]] = arith.addf %[[T11]], %[[C]] : f32 +// CHECK: %[[T11:.*]] = vector.reduction , %[[T10]], %[[C]] : vector<3xf32> into f32 // // CHECK: %[[T12:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> // CHECK: %[[T13:.*]] = vector.extract %[[B]][0, 1] : vector<3x2xf @@ -252,13 +247,12 @@ // CHECK: %[[T19:.*]] = vector.extract %[[B]][2, 1] : vector<3x2xf32> // CHECK: %[[T21:.*]] = vector.insert %[[T19]], %[[T18]] [2] : f32 into vector<3xf32> // CHECK: %[[T22:.*]] = arith.mulf %[[T12]], %[[T21]] : vector<3xf32> -// CHECK: %[[T23:.*]] = vector.reduction , %[[T22]] : vector<3xf32> into f32 -// CHECK: %[[ACC1:.*]] = arith.addf %[[T23]], %[[ACC0]] : f32 -// CHECK: return %[[ACC1]] : f32 +// CHECK: %[[T23:.*]] = vector.reduction , %[[T22]], %[[T11]] : vector<3xf32> into f32 +// CHECK: return %[[T23]] : f32 func.func @full_contract2(%arg0: vector<2x3xf32>, %arg1: vector<3x2xf32>, - %arg2: f32) -> f32 { + %arg2: f32) -> f32 { %0 = vector.contract #contraction2d_trans_trait %arg0, %arg1, %arg2 : vector<2x3xf32>, vector<3x2xf32> into f32 return %0 : f32