Index: mlir/lib/Dialect/Vector/VectorTransforms.cpp =================================================================== --- mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -1893,6 +1893,7 @@ } VectorType dstType = op.getResultType().cast(); + bool isInt = dstType.getElementType().isa(); assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 && "Expected dst type of rank 1 or 2"); @@ -1909,7 +1910,11 @@ Value b = rank == 1 ? rhs : rewriter.create(op.getLoc(), rhs, c); - Value m = rewriter.create(op.getLoc(), a, b); + Value m; + if (isInt) + m = rewriter.create(op.getLoc(), a, b); + else + m = rewriter.create(op.getLoc(), a, b); Value reduced = rewriter.create( op.getLoc(), dstType.getElementType(), rewriter.getStringAttr("add"), m, ValueRange{}); @@ -1919,8 +1924,12 @@ res = rewriter.create(op.getLoc(), reduced, res, pos); } } - if (auto acc = op.acc()) - res = rewriter.create(op.getLoc(), res, acc); + if (auto acc = op.acc()) { + if (isInt) + res = rewriter.create(op.getLoc(), res, acc); + else + res = rewriter.create(op.getLoc(), res, acc); + } rewriter.replaceOp(op, res); return success(); } @@ -2076,6 +2085,7 @@ VectorType rhsType = op.getRhsType(); Type resType = op.getResultType(); assert(!resType.isa()); + bool isInt = resType.isa(); // Use iterator index 0. int64_t iterIndex = 0; SmallVector iMap = op.getIndexingMaps(); @@ -2090,10 +2100,21 @@ // Base case. if (lhsType.getRank() == 1) { assert(rhsType.getRank() == 1 && "corrupt contraction"); - Value m = rewriter.create(loc, op.lhs(), op.rhs()); + Value m; + if (isInt) + m = rewriter.create(loc, op.lhs(), op.rhs()); + else + m = rewriter.create(loc, op.lhs(), op.rhs()); StringAttr kind = rewriter.getStringAttr("add"); - return rewriter.create(loc, resType, kind, m, - op.acc()); + Value res = rewriter.create(loc, resType, kind, m, + ValueRange{}); + if (auto acc = op.acc()) { + if (isInt) + res = rewriter.create(op.getLoc(), res, acc); + else + res = rewriter.create(op.getLoc(), res, acc); + } + return res; } // Construct new iterator types and affine map array attribute. std::array lowIndexingMaps = { Index: mlir/test/Dialect/Vector/vector-contract-transforms.mlir =================================================================== --- mlir/test/Dialect/Vector/vector-contract-transforms.mlir +++ mlir/test/Dialect/Vector/vector-contract-transforms.mlir @@ -18,8 +18,9 @@ // CHECK-SAME: %[[B:.*1]]: vector<4xf32>, // CHECK-SAME: %[[C:.*2]]: f32 // CHECK: %[[F:.*]] = mulf %[[A]], %[[B]] : vector<4xf32> -// CHECK: %[[R:.*]] = vector.reduction "add", %[[F]], %[[C]] : vector<4xf32> into f32 -// CHECK: return %[[R]] : f32 +// CHECK: %[[R:.*]] = vector.reduction "add", %[[F]] : vector<4xf32> into f32 +// CHECK: %[[ACC:.*]] = addf %[[R]], %[[C]] : f32 +// CHECK: return %[[ACC]] : f32 func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32) -> f32 { %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2 @@ -27,6 +28,21 @@ return %0 : f32 } +// CHECK-LABEL: func @extract_contract1_int +// CHECK-SAME: %[[A:.*0]]: vector<4xi32>, +// CHECK-SAME: %[[B:.*1]]: vector<4xi32>, +// CHECK-SAME: %[[C:.*2]]: i32 +// CHECK: %[[F:.*]] = muli %[[A]], %[[B]] : vector<4xi32> +// CHECK: %[[R:.*]] = vector.reduction "add", %[[F]] : vector<4xi32> into i32 +// CHECK: %[[ACC:.*]] = addi %[[R]], %[[C]] : i32 +// CHECK: return %[[ACC]] : i32 + +func @extract_contract1_int(%arg0: vector<4xi32>, %arg1: vector<4xi32>, %arg2: i32) -> i32 { + %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2 + : vector<4xi32>, vector<4xi32> into i32 + return %0 : i32 +} + #matvec_accesses = [ affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (j)>, @@ -61,6 +77,29 @@ return %0 : vector<2xf32> } +// CHECK-LABEL: func @extract_contract2_int +// CHECK-SAME: %[[A:.*0]]: vector<2x3xi32>, +// CHECK-SAME: %[[B:.*1]]: vector<3xi32>, +// CHECK-SAME: %[[C:.*2]]: vector<2xi32> +// CHECK: %[[R:.*]] = constant dense<0> : vector<2xi32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xi32> +// CHECK: %[[T2:.*]] = muli %[[T0]], %[[B]] : vector<3xi32> +// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]] : vector<3xi32> into i32 +// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : i32 into vector<2xi32> +// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xi32> +// CHECK: %[[T7:.*]] = muli %[[T5]], %[[B]] : vector<3xi32> +// CHECK: %[[T8:.*]] = vector.reduction "add", %[[T7]] : vector<3xi32> into i32 +// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : i32 into vector<2xi32> +// CHECK: %[[T10:.*]] = addi %[[T9]], %[[C]] : vector<2xi32> +// CHECK: return %[[T10]] : vector<2xi32> +func @extract_contract2_int(%arg0: vector<2x3xi32>, + %arg1: vector<3xi32>, + %arg2: vector<2xi32>) -> vector<2xi32> { + %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2 + : vector<2x3xi32>, vector<3xi32> into vector<2xi32> + return %0 : vector<2xi32> +} + #vecmat_accesses = [ affine_map<(i, j) -> (j)>, affine_map<(i, j) -> (i, j)>, @@ -162,12 +201,14 @@ // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32> // CHECK: %[[T1:.*]] = vector.extract %[[B]][0] : vector<2x3xf32> // CHECK: %[[T2:.*]] = mulf %[[T0]], %[[T1]] : vector<3xf32> -// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]], %[[C]] : vector<3xf32> into f32 -// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> -// CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32> -// CHECK: %[[T6:.*]] = mulf %[[T4]], %[[T5]] : vector<3xf32> -// CHECK: %[[T7:.*]] = vector.reduction "add", %[[T6]], %[[T3]] : vector<3xf32> into f32 -// CHECK: return %[[T7]] : f32 +// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]] : vector<3xf32> into f32 +// CHECK: %[[T4:.*]] = addf %[[T3]], %[[C]] : f32 +// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> +// CHECK: %[[T6:.*]] = vector.extract %[[B]][1] : vector<2x3xf32> +// CHECK: %[[T7:.*]] = mulf %[[T5]], %[[T6]] : vector<3xf32> +// CHECK: %[[T8:.*]] = vector.reduction "add", %[[T7]] : vector<3xf32> into f32 +// CHECK: %[[T9:.*]] = addf %[[T8]], %[[T4]] : f32 +// CHECK: return %[[T9]] : f32 func @full_contract1(%arg0: vector<2x3xf32>, %arg1: vector<2x3xf32>, @@ -200,7 +241,8 @@ // CHECK: %[[T7:.*]] = vector.extract %[[B]][2, 0] : vector<3x2xf32> // CHECK: %[[T9:.*]] = vector.insert %[[T7]], %[[T6]] [2] : f32 into vector<3xf32> // CHECK: %[[T10:.*]] = mulf %[[T0]], %[[T9]] : vector<3xf32> -// CHECK: %[[T11:.*]] = vector.reduction "add", %[[T10]], %[[C]] : vector<3xf32> into f32 +// CHECK: %[[T11:.*]] = vector.reduction "add", %[[T10]] : vector<3xf32> into f32 +// CHECK: %[[ACC0:.*]] = addf %[[T11]], %[[C]] : f32 // // CHECK: %[[T12:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> // CHECK: %[[T13:.*]] = vector.extract %[[B]][0, 1] : vector<3x2xf @@ -210,8 +252,9 @@ // CHECK: %[[T19:.*]] = vector.extract %[[B]][2, 1] : vector<3x2xf32> // CHECK: %[[T21:.*]] = vector.insert %[[T19]], %[[T18]] [2] : f32 into vector<3xf32> // CHECK: %[[T22:.*]] = mulf %[[T12]], %[[T21]] : vector<3xf32> -// CHECK: %[[T23:.*]] = vector.reduction "add", %[[T22]], %[[T11]] : vector<3xf32> into f32 -// CHECK: return %[[T23]] : f32 +// CHECK: %[[T23:.*]] = vector.reduction "add", %[[T22]] : vector<3xf32> into f32 +// CHECK: %[[ACC1:.*]] = addf %[[T23]], %[[ACC0]] : f32 +// CHECK: return %[[ACC1]] : f32 func @full_contract2(%arg0: vector<2x3xf32>, %arg1: vector<3x2xf32>,