diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -1716,6 +1716,24 @@ } // namespace +/// Creates an AddIOp if `isInt` is true, otherwise creates an AddFOp, using +/// operands `x` and `y`. +static Value createAdd(Location loc, Value x, Value y, bool isInt, + PatternRewriter &rewriter) { + if (isInt) + return rewriter.create(loc, x, y); + return rewriter.create(loc, x, y); +} + +/// Creates a MulIOp if `isInt` is true, otherwise creates a MulFOp, using +/// operands `x` and `y`. +static Value createMul(Location loc, Value x, Value y, bool isInt, + PatternRewriter &rewriter) { + if (isInt) + return rewriter.create(loc, x, y); + return rewriter.create(loc, x, y); +} + namespace mlir { /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul @@ -2003,13 +2021,14 @@ // ExtractOp does not allow dynamic indexing, we must unroll explicitly. Value res = rewriter.create(loc, dstType, rewriter.getZeroAttr(dstType)); + bool isInt = dstType.getElementType().isa(); for (unsigned r = 0; r < dstRows; ++r) { Value a = rewriter.create(op.getLoc(), lhs, r); for (unsigned c = 0; c < dstColumns; ++c) { Value b = rank == 1 ? rhs : rewriter.create(op.getLoc(), rhs, c); - Value m = rewriter.create(op.getLoc(), a, b); + Value m = createMul(op.getLoc(), a, b, isInt, rewriter); Value reduced = rewriter.create( op.getLoc(), dstType.getElementType(), rewriter.getStringAttr("add"), m, ValueRange{}); @@ -2020,7 +2039,7 @@ } } if (auto acc = op.acc()) - res = rewriter.create(op.getLoc(), res, acc); + res = createAdd(op.getLoc(), res, acc, isInt, rewriter); rewriter.replaceOp(op, res); return success(); } @@ -2176,6 +2195,7 @@ VectorType rhsType = op.getRhsType(); Type resType = op.getResultType(); assert(!resType.isa()); + bool isInt = resType.isa(); // Use iterator index 0.
int64_t iterIndex = 0; SmallVector iMap = op.getIndexingMaps(); @@ -2190,10 +2210,13 @@ // Base case. if (lhsType.getRank() == 1) { assert(rhsType.getRank() == 1 && "corrupt contraction"); - Value m = rewriter.create(loc, op.lhs(), op.rhs()); + Value m = createMul(loc, op.lhs(), op.rhs(), isInt, rewriter); StringAttr kind = rewriter.getStringAttr("add"); - return rewriter.create(loc, resType, kind, m, - op.acc()); + Value res = rewriter.create(loc, resType, kind, m, + ValueRange{}); + if (auto acc = op.acc()) + res = createAdd(op.getLoc(), res, acc, isInt, rewriter); + return res; } // Construct new iterator types and affine map array attribute. std::array lowIndexingMaps = { diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir @@ -18,8 +18,9 @@ // CHECK-SAME: %[[B:.*1]]: vector<4xf32>, // CHECK-SAME: %[[C:.*2]]: f32 // CHECK: %[[F:.*]] = mulf %[[A]], %[[B]] : vector<4xf32> -// CHECK: %[[R:.*]] = vector.reduction "add", %[[F]], %[[C]] : vector<4xf32> into f32 -// CHECK: return %[[R]] : f32 +// CHECK: %[[R:.*]] = vector.reduction "add", %[[F]] : vector<4xf32> into f32 +// CHECK: %[[ACC:.*]] = addf %[[R]], %[[C]] : f32 +// CHECK: return %[[ACC]] : f32 func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32) -> f32 { %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2 @@ -27,6 +28,21 @@ return %0 : f32 } +// CHECK-LABEL: func @extract_contract1_int +// CHECK-SAME: %[[A:.*0]]: vector<4xi32>, +// CHECK-SAME: %[[B:.*1]]: vector<4xi32>, +// CHECK-SAME: %[[C:.*2]]: i32 +// CHECK: %[[F:.*]] = muli %[[A]], %[[B]] : vector<4xi32> +// CHECK: %[[R:.*]] = vector.reduction "add", %[[F]] : vector<4xi32> into i32 +// CHECK: %[[ACC:.*]] = addi %[[R]], %[[C]] : i32 +// CHECK: return %[[ACC]] : i32 + +func @extract_contract1_int(%arg0: vector<4xi32>, %arg1:
vector<4xi32>, %arg2: i32) -> i32 { + %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2 + : vector<4xi32>, vector<4xi32> into i32 + return %0 : i32 +} + #matvec_accesses = [ affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (j)>, @@ -61,6 +77,29 @@ return %0 : vector<2xf32> } +// CHECK-LABEL: func @extract_contract2_int +// CHECK-SAME: %[[A:.*0]]: vector<2x3xi32>, +// CHECK-SAME: %[[B:.*1]]: vector<3xi32>, +// CHECK-SAME: %[[C:.*2]]: vector<2xi32> +// CHECK: %[[R:.*]] = constant dense<0> : vector<2xi32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xi32> +// CHECK: %[[T2:.*]] = muli %[[T0]], %[[B]] : vector<3xi32> +// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]] : vector<3xi32> into i32 +// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : i32 into vector<2xi32> +// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xi32> +// CHECK: %[[T7:.*]] = muli %[[T5]], %[[B]] : vector<3xi32> +// CHECK: %[[T8:.*]] = vector.reduction "add", %[[T7]] : vector<3xi32> into i32 +// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : i32 into vector<2xi32> +// CHECK: %[[T10:.*]] = addi %[[T9]], %[[C]] : vector<2xi32> +// CHECK: return %[[T10]] : vector<2xi32> +func @extract_contract2_int(%arg0: vector<2x3xi32>, + %arg1: vector<3xi32>, + %arg2: vector<2xi32>) -> vector<2xi32> { + %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2 + : vector<2x3xi32>, vector<3xi32> into vector<2xi32> + return %0 : vector<2xi32> +} + #vecmat_accesses = [ affine_map<(i, j) -> (j)>, affine_map<(i, j) -> (i, j)>, @@ -162,12 +201,14 @@ // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32> // CHECK: %[[T1:.*]] = vector.extract %[[B]][0] : vector<2x3xf32> // CHECK: %[[T2:.*]] = mulf %[[T0]], %[[T1]] : vector<3xf32> -// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]], %[[C]] : vector<3xf32> into f32 -// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> -// CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32> -// CHECK:
%[[T6:.*]] = mulf %[[T4]], %[[T5]] : vector<3xf32> -// CHECK: %[[T7:.*]] = vector.reduction "add", %[[T6]], %[[T3]] : vector<3xf32> into f32 -// CHECK: return %[[T7]] : f32 +// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]] : vector<3xf32> into f32 +// CHECK: %[[T4:.*]] = addf %[[T3]], %[[C]] : f32 +// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> +// CHECK: %[[T6:.*]] = vector.extract %[[B]][1] : vector<2x3xf32> +// CHECK: %[[T7:.*]] = mulf %[[T5]], %[[T6]] : vector<3xf32> +// CHECK: %[[T8:.*]] = vector.reduction "add", %[[T7]] : vector<3xf32> into f32 +// CHECK: %[[T9:.*]] = addf %[[T8]], %[[T4]] : f32 +// CHECK: return %[[T9]] : f32 func @full_contract1(%arg0: vector<2x3xf32>, %arg1: vector<2x3xf32>, @@ -200,7 +241,8 @@ // CHECK: %[[T7:.*]] = vector.extract %[[B]][2, 0] : vector<3x2xf32> // CHECK: %[[T9:.*]] = vector.insert %[[T7]], %[[T6]] [2] : f32 into vector<3xf32> // CHECK: %[[T10:.*]] = mulf %[[T0]], %[[T9]] : vector<3xf32> -// CHECK: %[[T11:.*]] = vector.reduction "add", %[[T10]], %[[C]] : vector<3xf32> into f32 +// CHECK: %[[T11:.*]] = vector.reduction "add", %[[T10]] : vector<3xf32> into f32 +// CHECK: %[[ACC0:.*]] = addf %[[T11]], %[[C]] : f32 // // CHECK: %[[T12:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> // CHECK: %[[T13:.*]] = vector.extract %[[B]][0, 1] : vector<3x2xf @@ -210,8 +252,9 @@ // CHECK: %[[T19:.*]] = vector.extract %[[B]][2, 1] : vector<3x2xf32> // CHECK: %[[T21:.*]] = vector.insert %[[T19]], %[[T18]] [2] : f32 into vector<3xf32> // CHECK: %[[T22:.*]] = mulf %[[T12]], %[[T21]] : vector<3xf32> -// CHECK: %[[T23:.*]] = vector.reduction "add", %[[T22]], %[[T11]] : vector<3xf32> into f32 -// CHECK: return %[[T23]] : f32 +// CHECK: %[[T23:.*]] = vector.reduction "add", %[[T22]] : vector<3xf32> into f32 +// CHECK: %[[ACC1:.*]] = addf %[[T23]], %[[ACC0]] : f32 +// CHECK: return %[[ACC1]] : f32 func @full_contract2(%arg0: vector<2x3xf32>, %arg1: vector<3x2xf32>,