diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -1291,8 +1291,8 @@ Value b = rewriter.create(loc, rhsType, x); Value m; if (acc) { - Value z = rewriter.create(loc, rhsType, acc, pos); - m = rewriter.create(loc, b, op.rhs(), z); + Value e = rewriter.create(loc, rhsType, acc, pos); + m = rewriter.create(loc, b, op.rhs(), e); } else { m = rewriter.create(loc, b, op.rhs()); } @@ -1732,7 +1732,7 @@ /// .. /// %x = combine %a %b .. /// until a pure contraction is reached (no free/batch dimensions), -/// which is replaced by a fma/reduction op. +/// which is replaced by a dot-product/reduction pair. /// /// TODO(ajcbik): break down into transpose/reshape/cast ops /// when they become available to avoid code dup @@ -1882,11 +1882,9 @@ // Base case. if (lhsType.getRank() == 1) { assert(rhsType.getRank() == 1 && "corrupt contraction"); - Value zero = rewriter.create(loc, lhsType, - rewriter.getZeroAttr(lhsType)); - Value fma = rewriter.create(loc, op.lhs(), op.rhs(), zero); + Value m = rewriter.create(loc, op.lhs(), op.rhs()); StringAttr kind = rewriter.getStringAttr("add"); - return rewriter.create(loc, resType, kind, fma, + return rewriter.create(loc, resType, kind, m, op.acc()); } // Construct new iterator types and affine map array attribute. diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir @@ -16,8 +16,7 @@ // CHECK-SAME: %[[A:.*0]]: vector<4xf32>, // CHECK-SAME: %[[B:.*1]]: vector<4xf32>, // CHECK-SAME: %[[C:.*2]]: f32 -// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<4xf32> -// CHECK: %[[F:.*]] = vector.fma %[[A]], %[[B]], %[[Z]] : vector<4xf32> +// CHECK: %[[F:.*]] = mulf %[[A]], %[[B]] : vector<4xf32> // CHECK: %[[R:.*]] = vector.reduction "add", %[[F]], %[[C]] : vector<4xf32> into f32 // CHECK: return %[[R]] : f32 @@ -42,15 +41,14 @@ // CHECK-SAME: %[[B:.*1]]: vector<3xf32>, // CHECK-SAME: %[[C:.*2]]: vector<2xf32> // CHECK: %[[R:.*]] = constant dense<0.000000e+00> : vector<2xf32> -// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<3xf32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32> // CHECK: %[[T1:.*]] = vector.extract %[[C]][0] : vector<2xf32> -// CHECK: %[[T2:.*]] = vector.fma %[[T0]], %[[B]], %[[Z]] : vector<3xf32> +// CHECK: %[[T2:.*]] = mulf %[[T0]], %[[B]] : vector<3xf32> // CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]], %[[T1]] : vector<3xf32> into f32 // CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32> // CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> // CHECK: %[[T6:.*]] = vector.extract %[[C]][1] : vector<2xf32> -// CHECK: %[[T7:.*]] = vector.fma %[[T5]], %[[B]], %[[Z]] : vector<3xf32> +// CHECK: %[[T7:.*]] = mulf %[[T5]], %[[B]] : vector<3xf32> // CHECK: %[[T8:.*]] = vector.reduction "add", %[[T7]], %[[T6]] : vector<3xf32> into f32 // CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32> // CHECK: return %[[T9]] : vector<2xf32> @@ -78,15 +76,14 @@ // CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>, // CHECK-SAME: %[[C:.*2]]: vector<2xf32> // CHECK: %[[R:.*]] = constant dense<0.000000e+00> : vector<2xf32> -// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<3xf32> // CHECK: %[[T0:.*]] = vector.extract %[[B]][0] : vector<2x3xf32> // CHECK: %[[T1:.*]] = vector.extract %[[C]][0] : vector<2xf32> -// CHECK: %[[T2:.*]] = vector.fma %[[A]], %[[T0]], %[[Z]] : vector<3xf32> +// CHECK: %[[T2:.*]] = mulf %[[A]], %[[T0]] : vector<3xf32> // CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]], %[[T1]] : vector<3xf32> into f32 // CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32> // CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32> // CHECK: %[[T6:.*]] = vector.extract %[[C]][1] : vector<2xf32> -// CHECK: %[[T7:.*]] = vector.fma %[[A]], %[[T5]], %[[Z]] : vector<3xf32> +// CHECK: %[[T7:.*]] = mulf %[[A]], %[[T5]] : vector<3xf32> // CHECK: %[[T8:.*]] = vector.reduction "add", %[[T7]], %[[T6]] : vector<3xf32> into f32 // CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32> // CHECK: return %[[T9]] : vector<2xf32> @@ -124,7 +121,7 @@ // CHECK: %[[T6:.*]] = vector.extract %[[T5]][0] : vector<2xf32> // CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T4]] [1] : f32 into vector<2xf32> // CHECK: %[[T8:.*]] = vector.extract %[[T1]][0] : vector<2xf32> -// CHECK: %[[T9:.*]] = vector.fma %[[T0]], %[[T7]], %[[Z]] : vector<2xf32> +// CHECK: %[[T9:.*]] = mulf %[[T0]], %[[T7]] : vector<2xf32> // CHECK: %[[T10:.*]] = vector.reduction "add", %[[T9]], %[[T8]] : vector<2xf32> into f32 // CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[Z]] [0] : f32 into vector<2xf32> // CHECK: %[[T12:.*]] = vector.extract %[[B]][0] : vector<2x2xf32> @@ -134,7 +131,7 @@ // CHECK: %[[T16:.*]] = vector.extract %[[T15]][1] : vector<2xf32> // CHECK: %[[T17:.*]] = vector.insert %[[T16]], %[[T14]] [1] : f32 into vector<2xf32> // CHECK: %[[T18:.*]] = vector.extract %[[T1]][1] : vector<2xf32> -// CHECK: %[[T19:.*]] = vector.fma %[[T0]], %[[T17]], %[[Z]] : vector<2xf32> +// CHECK: %[[T19:.*]] = mulf %[[T0]], %[[T17]] : vector<2xf32> // CHECK: %[[T20:.*]] = vector.reduction "add", %[[T19]], %[[T18]] : vector<2xf32> into f32 // CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [1] : f32 into vector<2xf32> // CHECK: %[[T22:.*]] = vector.insert %[[T21]], %[[R]] [0] : vector<2xf32> into vector<2x2xf32> @@ -147,7 +144,7 @@ // CHECK: %[[T29:.*]] = vector.extract %[[T28]][0] : vector<2xf32> // CHECK: %[[T30:.*]] = vector.insert %[[T29]], %[[T27]] [1] : f32 into vector<2xf32> // CHECK: %[[T31:.*]] = vector.extract %[[T24]][0] : vector<2xf32> -// CHECK: %[[T32:.*]] = vector.fma %[[T23]], %[[T30]], %[[Z]] : vector<2xf32> +// CHECK: %[[T32:.*]] = mulf %[[T23]], %[[T30]] : vector<2xf32> // CHECK: %[[T33:.*]] = vector.reduction "add", %[[T32]], %[[T31]] : vector<2xf32> into f32 // CHECK: %[[T34:.*]] = vector.insert %[[T33]], %[[Z]] [0] : f32 into vector<2xf32> // CHECK: %[[T35:.*]] = vector.extract %[[B]][0] : vector<2x2xf32> @@ -157,7 +154,7 @@ // CHECK: %[[T39:.*]] = vector.extract %[[T38]][1] : vector<2xf32> // CHECK: %[[T40:.*]] = vector.insert %[[T39]], %[[T37]] [1] : f32 into vector<2xf32> // CHECK: %[[T41:.*]] = vector.extract %[[T24]][1] : vector<2xf32> -// CHECK: %[[T42:.*]] = vector.fma %[[T23]], %[[T40]], %[[Z]] : vector<2xf32> +// CHECK: %[[T42:.*]] = mulf %[[T23]], %[[T40]] : vector<2xf32> // CHECK: %[[T43:.*]] = vector.reduction "add", %[[T42]], %[[T41]] : vector<2xf32> into f32 // CHECK: %[[T44:.*]] = vector.insert %[[T43]], %[[T34]] [1] : f32 into vector<2xf32> // CHECK: %[[T45:.*]] = vector.insert %[[T44]], %[[T22]] [1] : vector<2xf32> into vector<2x2xf32> @@ -185,14 +182,13 @@ // CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>, // CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>, // CHECK-SAME: %[[C:.*2]]: f32 -// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<3xf32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32> // CHECK: %[[T1:.*]] = vector.extract %[[B]][0] : vector<2x3xf32> -// CHECK: %[[T2:.*]] = vector.fma %[[T0]], %[[T1]], %[[Z]] : vector<3xf32> +// CHECK: %[[T2:.*]] = mulf %[[T0]], %[[T1]] : vector<3xf32> // CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]], %[[C]] : vector<3xf32> into f32 // CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> // CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32> -// CHECK: %[[T6:.*]] = vector.fma %[[T4]], %[[T5]], %[[Z]] : vector<3xf32> +// CHECK: %[[T6:.*]] = mulf %[[T4]], %[[T5]] : vector<3xf32> // CHECK: %[[T7:.*]] = vector.reduction "add", %[[T6]], %[[T3]] : vector<3xf32> into f32 // CHECK: return %[[T7]] : f32 @@ -229,7 +225,7 @@ // CHECK: %[[T7:.*]] = vector.extract %[[B]][2] : vector<3x2xf32> // CHECK: %[[T8:.*]] = vector.extract %[[T7]][0] : vector<2xf32> // CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T6]] [2] : f32 into vector<3xf32> -// CHECK: %[[T10:.*]] = vector.fma %[[T0]], %[[T9]], %[[Z]] : vector<3xf32> +// CHECK: %[[T10:.*]] = mulf %[[T0]], %[[T9]] : vector<3xf32> // CHECK: %[[T11:.*]] = vector.reduction "add", %[[T10]], %[[C]] : vector<3xf32> into f32 // CHECK: %[[T12:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> // CHECK: %[[T13:.*]] = vector.extract %[[B]][0] : vector<3x2xf32> @@ -241,7 +237,7 @@ // CHECK: %[[T19:.*]] = vector.extract %[[B]][2] : vector<3x2xf32> // CHECK: %[[T20:.*]] = vector.extract %[[T19]][1] : vector<2xf32> // CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T18]] [2] : f32 into vector<3xf32> -// CHECK: %[[T22:.*]] = vector.fma %[[T12]], %[[T21]], %[[Z]] : vector<3xf32> +// CHECK: %[[T22:.*]] = mulf %[[T12]], %[[T21]] : vector<3xf32> // CHECK: %[[T23:.*]] = vector.reduction "add", %[[T22]], %[[T11]] : vector<3xf32> into f32 // CHECK: return %[[T23]] : f32