diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -47,7 +47,8 @@ TCresVTEtIsSameAsOpBase<0, 2>>, DeclareOpInterfaceMethods ]>, - Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc, + Arguments<(ins DefaultValuedAttr:$kind, + AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc, Variadic>:$masks, AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types)>, Results<(outs AnyType)> { @@ -88,6 +89,10 @@ and acc arguments. An indexing map attribute specifies a mapping from each iterator in the iterator type list, to each dimension of an N-D vector. + An optional kind attribute may be used to specify the joining operation + between the intermediate result and accumulator argument of rank K. This + attribute can take the values {"add", "max"}. The default is "add". + Example: ```mlir @@ -146,6 +151,20 @@ // types than accumulator/result. %6 = vector.contract #contraction_trait %0, %1, %2 : vector<10xf16>, vector<10xf16> into f32 + + // Contract with max (K = 0). 
+ #contraction_accesses = [ + affine_map<(i) -> (i)>, + affine_map<(i) -> (i)>, + affine_map<(i) -> ()> + ] + #contraction_trait = { + indexing_maps = #contraction_accesses, + iterator_types = ["reduction"], + kind = "max" + } + %7 = vector.contract #contraction_trait %0, %1, %2 + : vector<10xf32>, vector<10xf32> into f32 ``` }]; let builders = [ @@ -819,9 +838,11 @@ TCresVTEtIsSameAsOpBase<0, 0>>, PredOpTrait<"rhs operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 1>>]>, - Arguments<(ins AnyVector:$lhs, AnyType:$rhs, Variadic:$acc)>, + Arguments<(ins AnyVector:$lhs, AnyType:$rhs, + DefaultValuedAttr:$kind, + Variadic:$acc)>, Results<(outs AnyVector)> { - let summary = "vector outerproduct with optional fused add"; + let summary = "vector outerproduct with optional max or fused add"; let description = [{ Takes 2 1-D vectors and returns the 2-D vector containing the outer-product, as illustrated below: @@ -845,6 +866,10 @@ lowered to the LLVMIR dialect, this form emits `llvm.intr.fma`, which is guaranteed to lower to actual `fma` instructions on x86. + An optional kind attribute may be specified to be one of {"add", "max"}. + The default is "add". In case of "max", the operation returns the max + of the outer-product and the extra vector (if present). 
+ Example: ``` @@ -855,6 +880,10 @@ vector<4xf32>, vector<8xf32>, vector<4x8xf32> return %3: vector<4x8xf32> + %4 = vector.outerproduct %0, %1, %2 {kind = "max"}: + vector<4xf32>, vector<8xf32>, vector<4x8xf32> + return %4: vector<4x8xf32> + %6 = vector.outerproduct %4, %5: vector<10xf32>, f32 return %6: vector<10xf32> diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -196,6 +196,7 @@ AffineMap::inferFromExprList(indexingExprs))); result.addAttribute(getIteratorTypesAttrName(), builder.getStrArrayAttr(iteratorTypes)); + result.addAttribute("kind", builder.getStringAttr("add")); } void vector::ContractionOp::build(OpBuilder &builder, OperationState &result, @@ -206,6 +207,7 @@ result.addTypes(acc.getType()); result.addAttribute(getIndexingMapsAttrName(), indexingMaps); result.addAttribute(getIteratorTypesAttrName(), iteratorTypes); + result.addAttribute("kind", builder.getStringAttr("add")); } static ParseResult parseContractionOp(OpAsmParser &parser, @@ -234,6 +236,10 @@ return failure(); result.attributes.assign(dictAttr.getValue().begin(), dictAttr.getValue().end()); + if (!result.attributes.get("kind")) { + result.attributes.append("kind", + StringAttr::get("add", result.getContext())); + } if (masksInfo.empty()) return success(); if (masksInfo.size() != 2) @@ -438,8 +444,8 @@ } ArrayRef ContractionOp::getTraitAttrNames() { - static constexpr StringRef names[2] = {getIndexingMapsAttrName(), - getIteratorTypesAttrName()}; + static constexpr StringRef names[3] = {getIndexingMapsAttrName(), + getIteratorTypesAttrName(), "kind"}; return llvm::makeArrayRef(names); } @@ -1476,12 +1482,15 @@ Value lhs, Value rhs, Value acc) { result.addOperands({lhs, rhs, acc}); result.addTypes(acc.getType()); + result.addAttribute("kind", StringAttr::get("add", result.getContext())); } static void print(OpAsmPrinter &p, OuterProductOp op) { p << 
op.getOperationName() << " " << op.lhs() << ", " << op.rhs(); - if (!op.acc().empty()) + if (!op.acc().empty()) { p << ", " << op.acc(); + p.printOptionalAttrDict(op.getAttrs()); + } p << " : " << op.lhs().getType() << ", " << op.rhs().getType(); } @@ -1489,8 +1498,10 @@ OperationState &result) { SmallVector operandsInfo; Type tLHS, tRHS; - if (parser.parseOperandList(operandsInfo) || parser.parseColonType(tLHS) || - parser.parseComma() || parser.parseType(tRHS)) + if (parser.parseOperandList(operandsInfo) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(tLHS) || parser.parseComma() || + parser.parseType(tRHS)) return failure(); if (operandsInfo.size() < 2) return parser.emitError(parser.getNameLoc(), @@ -1504,6 +1515,12 @@ vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)}, vLHS.getElementType()) : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType()); + + if (operandsInfo.size() > 2 && !result.attributes.get("kind")) { + result.attributes.append("kind", + StringAttr::get("add", result.getContext())); + } + return failure( parser.resolveOperand(operandsInfo[0], tLHS, result.operands) || parser.resolveOperand(operandsInfo[1], tRHS, result.operands) || diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -1339,10 +1339,18 @@ bool isInt = eltType.isa(); Value acc = (op.acc().empty()) ? nullptr : op.acc()[0]; + constexpr StringRef valid_kinds[2] = {"add", "max"}; + StringRef kind = op.kind(); + if (!llvm::is_contained(valid_kinds, kind)) { + return failure(); + } + if (!rhsType) { // Special case: AXPY operation. Value b = rewriter.create(loc, lhsType, op.rhs()); - rewriter.replaceOp(op, genMult(loc, op.lhs(), b, acc, isInt, rewriter)); + Value mult = isInt ? 
genMultI(loc, op.lhs(), b, acc, kind, rewriter) + : genMultF(loc, op.lhs(), b, acc, kind, rewriter); + rewriter.replaceOp(op, mult); return success(); } @@ -1355,7 +1363,8 @@ Value r = nullptr; if (acc) r = rewriter.create(loc, rhsType, acc, pos); - Value m = genMult(loc, a, op.rhs(), r, isInt, rewriter); + Value m = isInt ? genMultI(loc, a, op.rhs(), r, kind, rewriter) + : genMultF(loc, a, op.rhs(), r, kind, rewriter); result = rewriter.create(loc, resType, m, result, pos); } rewriter.replaceOp(op, result); @@ -1363,16 +1372,35 @@ } private: - static Value genMult(Location loc, Value x, Value y, Value acc, bool isInt, - PatternRewriter &rewriter) { + static Value genMultI(Location loc, Value x, Value y, Value acc, + StringRef kind, PatternRewriter &rewriter) { + MulIOp mul = rewriter.create(loc, x, y); if (acc) { - if (isInt) - return rewriter.create(loc, rewriter.create(loc, x, y), - acc); - return rewriter.create(loc, x, y, acc); + if (kind == "add") { + return rewriter.create(loc, mul, acc); + } + if (kind == "max") { + return rewriter.create( + loc, rewriter.create(loc, CmpIPredicate::sge, mul, acc), + mul, acc); + } + } + return mul; + } + + static Value genMultF(Location loc, Value x, Value y, Value acc, + StringRef kind, PatternRewriter &rewriter) { + if (acc) { + if (kind == "add") { + return rewriter.create(loc, x, y, acc); + } + if (kind == "max") { + MulFOp mul = rewriter.create(loc, x, y); + return rewriter.create( + loc, rewriter.create(loc, CmpFPredicate::OGT, mul, acc), + mul, acc); + } } - if (isInt) - return rewriter.create(loc, x, y); return rewriter.create(loc, x, y); } }; @@ -1788,7 +1816,8 @@ for (int64_t k = 0; k < reductionSize; ++k) { Value a = rewriter.create(op.getLoc(), lhs, k); Value b = rewriter.create(op.getLoc(), rhs, k); - res = rewriter.create(op.getLoc(), a, b, res); + res = rewriter.create(op.getLoc(), res.getType(), a, + b, op.kind(), res); } rewriter.replaceOp(op, res); return success(); diff --git 
a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -67,7 +67,7 @@ // CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32> // CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<16x32xf32> // CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32> -// CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32> +// CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"], kind = "add"} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32> // CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32> func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>, @@ -86,7 +86,7 @@ // CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x16xi32> // CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<16x32xi32> // CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32> -// CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xi32>, vector<16x32xi32> into vector<8x32xi32> +// CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"], kind = "add"} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xi32>, vector<16x32xi32> into vector<8x32xi32> // CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32> func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>, diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- 
a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -150,13 +150,34 @@ func @contraction_to_scalar(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32 { // CHECK: %[[C0:.*]] = constant 0.000000e+00 : f32 %f0 = constant 0.0: f32 - // CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"]} %{{.*}}, %{{.*}}, %[[C0]] : vector<10xf32>, vector<10xf32> into f32 + // CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"], kind = "add"} %{{.*}}, %{{.*}}, %[[C0]] : vector<10xf32>, vector<10xf32> into f32 %0 = vector.contract #contraction_to_scalar_trait %arg0, %arg1, %f0 : vector<10xf32>, vector<10xf32> into f32 // CHECK: return %[[X]] : f32 return %0 : f32 } +#contraction_to_scalar_max_accesses = [ + affine_map<(i) -> (i)>, + affine_map<(i) -> (i)>, + affine_map<(i) -> ()> +] +#contraction_to_scalar_max_trait = { + indexing_maps = #contraction_to_scalar_max_accesses, + iterator_types = ["reduction"], + kind = "max" +} +// CHECK-LABEL: @contraction_to_scalar_with_max +func @contraction_to_scalar_with_max(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32 { + // CHECK: %[[C0:.*]] = constant 0.000000e+00 : f32 + %f0 = constant 0.0: f32 + // CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"], kind = "max"} %{{.*}}, %{{.*}}, %[[C0]] : vector<10xf32>, vector<10xf32> into f32 + %0 = vector.contract #contraction_to_scalar_max_trait %arg0, %arg1, %f0 + : vector<10xf32>, vector<10xf32> into f32 + // CHECK: return %[[X]] : f32 + return %0 : f32 +} + #contraction_accesses0 = [ affine_map<(b0, f0, f1, c0, c1) -> (c0, b0, c1, f0)>, affine_map<(b0, f0, f1, c0, c1) -> (b0, c1, c0, f1)>, @@ -173,36 +194,46 @@ // 8, 8, 15, 5 affine_map<(f0, f1, f2, f3, c0, c1) -> (f0, f1, f2, f3)> ] +#iterator_types1 = ["parallel", "parallel", "parallel", "parallel", "reduction", + "reduction"] 
#contraction_trait1 = { indexing_maps = #contraction_accesses1, - iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", - "reduction"] + iterator_types = #iterator_types1 +} +#contraction_trait2 = { + indexing_maps = #contraction_accesses1, + iterator_types = #iterator_types1, + kind = "max" } // CHECK-LABEL: @contraction func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>, %arg2 : vector<8x15x5xf32>, %arg3 : vector<8x8x15x5xf32>, %arg4 : vector<7x8x16x15xf16>, %arg5 : vector<8x16x7x5xf16>) { // Test contraction with batch and contracting dims. - // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32> + // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"], kind = "add"} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32> %0 = vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32> // Test contraction with only contracting dims. In this case the lhs/rhs // dimension of size 8 will be considered a parallel dim for lhs/rhs and will // appear twice in the output. 
- // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> + // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = "add"} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> %1 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3 : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> // Test contraction with optional vector mask arguments. %lhs_mask = vector.constant_mask [7, 8, 16, 15] : vector<7x8x16x15xi1> %rhs_mask = vector.constant_mask [8, 16, 7, 5] : vector<8x16x7x5xi1> - // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> + // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = "add"} {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> %2 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3, %lhs_mask, %rhs_mask : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> // Test contraction with mixed type. 
- // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32> + // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = "add"} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32> %3 = vector.contract #contraction_trait1 %arg4, %arg5, %arg3 : vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32> + // Test contraction with "max" instead of "add". + // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = "max"} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> + %4 = vector.contract #contraction_trait2 %arg0, %arg1, %arg3 + : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> return } diff --git a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir @@ -50,10 +50,10 @@ // CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32> // CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2x2xf32> // CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : vector<2xf32> -// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] : vector<2xf32>, f32 +// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = "add"} : vector<2xf32>, f32 // CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2x2xf32> // CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : 
vector<2xf32> -// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] : vector<2xf32>, f32 +// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = "add"} : vector<2xf32>, f32 // CHECK: store %[[T9]], %[[C]][] : memref> // CHECK: return func @matvec2x2(%arg0: memref>, %arg1: memref>, @@ -75,10 +75,10 @@ // CHECK: %[[T2:.*]] = load %[[C]][] : memref> // CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2x2xf32> // CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : vector<2xf32> -// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] : vector<2xf32>, f32 +// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] {kind = "add"} : vector<2xf32>, f32 // CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2x2xf32> // CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : vector<2xf32> -// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] : vector<2xf32>, f32 +// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = "add"} : vector<2xf32>, f32 // CHECK: store %[[T8]], %[[C]][] : memref> // CHECK: return func @mattransvec2x2(%arg0: memref>, %arg1: memref>, @@ -101,10 +101,10 @@ // CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32> // CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2x2xf32> // CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : vector<2xf32> -// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] : vector<2xf32>, f32 +// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = "add"} : vector<2xf32>, f32 // CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2x2xf32> // CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : vector<2xf32> -// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] : vector<2xf32>, f32 +// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = "add"} : vector<2xf32>, f32 // CHECK: store %[[T9]], %[[C]][] : memref> // CHECK: return func 
@vecmat2x2(%arg0: memref>, %arg1: memref>, @@ -126,10 +126,10 @@ // CHECK: %[[T2:.*]] = load %[[C]][] : memref> // CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2x2xf32> // CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : vector<2xf32> -// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] : vector<2xf32>, f32 +// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] {kind = "add"} : vector<2xf32>, f32 // CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2x2xf32> // CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : vector<2xf32> -// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] : vector<2xf32>, f32 +// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = "add"} : vector<2xf32>, f32 // CHECK: store %[[T8]], %[[C]][] : memref> // CHECK: return func @vecmattrans2x2(%arg0: memref>, %arg1: memref>, diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-transforms.mlir @@ -93,56 +93,56 @@ // CHECK-NEXT: %[[TG3:.*]] = vector.tuple_get %[[ES3]], 0 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG4:.*]] = vector.tuple_get %[[ES4]], 0 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> // CHECK-NEXT: %[[TG5:.*]] = vector.tuple_get %[[ES5]], 0 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> -// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG1]], %[[TG2]], %[[TG3]], %[[TG4]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = "add"} %[[TG1]], %[[TG2]], %[[TG3]], %[[TG4]], %[[TG5]] : 
vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[TG6:.*]] = vector.tuple_get %[[ES1]], 1 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG7:.*]] = vector.tuple_get %[[ES2]], 2 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG8:.*]] = vector.tuple_get %[[ES4]], 1 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> // CHECK-NEXT: %[[TG9:.*]] = vector.tuple_get %[[ES5]], 2 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> -// CHECK-NEXT: %[[R2S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG6]], %[[TG7]], %[[R1S00]], %[[TG8]], %[[TG9]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R2S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = "add"} %[[TG6]], %[[TG7]], %[[R1S00]], %[[TG8]], %[[TG9]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[TG10:.*]] = vector.tuple_get %[[ES1]], 2 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG11:.*]] = vector.tuple_get %[[ES2]], 4 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG12:.*]] = vector.tuple_get %[[ES4]], 2 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> // CHECK-NEXT: %[[TG13:.*]] = vector.tuple_get %[[ES5]], 4 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> -// CHECK-NEXT: %[[R3S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG10]], %[[TG11]], %[[R2S00]], %[[TG12]], %[[TG13]] : vector<2x2xf32>, 
vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R3S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = "add"} %[[TG10]], %[[TG11]], %[[R2S00]], %[[TG12]], %[[TG13]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // Reducing output vector [0, 2] // CHECK-NEXT: %[[TG14:.*]] = vector.tuple_get %[[ES2]], 1 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG15:.*]] = vector.tuple_get %[[ES3]], 1 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG16:.*]] = vector.tuple_get %[[ES5]], 1 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> -// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG1]], %[[TG14]], %[[TG15]], %[[TG4]], %[[TG16]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = "add"} %[[TG1]], %[[TG14]], %[[TG15]], %[[TG4]], %[[TG16]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[TG17:.*]] = vector.tuple_get %[[ES2]], 3 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG18:.*]] = vector.tuple_get %[[ES5]], 3 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> -// CHECK-NEXT: %[[R2S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG6]], %[[TG17]], %[[R1S02]], %[[TG8]], %[[TG18]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R2S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = "add"} %[[TG6]], 
%[[TG17]], %[[R1S02]], %[[TG8]], %[[TG18]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[TG19:.*]] = vector.tuple_get %[[ES2]], 5 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG20:.*]] = vector.tuple_get %[[ES5]], 5 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> -// CHECK-NEXT: %[[R3S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG10]], %[[TG19]], %[[R2S02]], %[[TG12]], %[[TG20]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R3S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = "add"} %[[TG10]], %[[TG19]], %[[R2S02]], %[[TG12]], %[[TG20]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // Reducing output vector [2, 0] // CHECK-NEXT: %[[TG21:.*]] = vector.tuple_get %[[ES1]], 3 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG22:.*]] = vector.tuple_get %[[ES3]], 2 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG23:.*]] = vector.tuple_get %[[ES4]], 3 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> -// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG21]], %[[TG2]], %[[TG22]], %[[TG23]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = "add"} %[[TG21]], %[[TG2]], %[[TG22]], %[[TG23]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[TG24:.*]] = vector.tuple_get %[[ES1]], 4 : tuple, vector<2x2xf32>, vector<2x2xf32>, 
vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG25:.*]] = vector.tuple_get %[[ES4]], 4 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> -// CHECK-NEXT: %[[R2S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG24]], %[[TG7]], %[[R1S20]], %[[TG25]], %[[TG9]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R2S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = "add"} %[[TG24]], %[[TG7]], %[[R1S20]], %[[TG25]], %[[TG9]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[TG26:.*]] = vector.tuple_get %[[ES1]], 5 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG27:.*]] = vector.tuple_get %[[ES4]], 5 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> -// CHECK-NEXT: %[[R3S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG26]], %[[TG11]], %[[R2S20]], %[[TG27]], %[[TG13]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R3S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = "add"} %[[TG26]], %[[TG11]], %[[R2S20]], %[[TG27]], %[[TG13]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // Reducing output vector [2, 2] // CHECK-NEXT: %[[TG28:.*]] = vector.tuple_get %[[ES3]], 3 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> -// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG21]], %[[TG14]], %[[TG28]], %[[TG23]], %[[TG16]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> -// CHECK-NEXT: %[[R2S22:.*]] = vector.contract 
{indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG24]], %[[TG17]], %[[R1S22]], %[[TG25]], %[[TG18]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> -// CHECK-NEXT: %[[R3S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG26]], %[[TG19]], %[[R2S22]], %[[TG27]], %[[TG20]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = "add"} %[[TG21]], %[[TG14]], %[[TG28]], %[[TG23]], %[[TG16]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R2S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = "add"} %[[TG24]], %[[TG17]], %[[R1S22]], %[[TG25]], %[[TG18]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R3S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = "add"} %[[TG26]], %[[TG19]], %[[R2S22]], %[[TG27]], %[[TG20]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[RES0:.*]] = vector.tuple %[[R3S00]], %[[R3S02]], %[[R3S20]], %[[R3S22]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32> // CHECK-NEXT: %[[RES1:.*]] = vector.insert_slices %[[RES0]], [2, 2], [1, 1] : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32> @@ -188,26 +188,26 @@ // CHECK-NEXT: %[[TG3:.*]] = vector.tuple_get %[[ES3]], 0 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG4:.*]] = vector.tuple_get %[[ES4]], 0 : tuple, vector<2x2xi1>> // CHECK-NEXT: %[[TG5:.*]] = vector.tuple_get %[[ES5]], 0 : tuple, vector<2x2xi1>> -// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = 
["parallel", "reduction", "parallel"]} %[[TG1]], %[[TG2]], %[[TG3]], %[[TG4]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = "add"} %[[TG1]], %[[TG2]], %[[TG3]], %[[TG4]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // Reducing output vector [0, 2] // CHECK-NEXT: %[[TG6:.*]] = vector.tuple_get %[[ES2]], 1 : tuple, vector<2x2xf32>> // CHECK-NEXT: %[[TG7:.*]] = vector.tuple_get %[[ES3]], 1 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG8:.*]] = vector.tuple_get %[[ES5]], 1 : tuple, vector<2x2xi1>> -// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[TG1]], %[[TG6]], %[[TG7]], %[[TG4]], %[[TG8]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = "add"} %[[TG1]], %[[TG6]], %[[TG7]], %[[TG4]], %[[TG8]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // Reducing output vector [2, 0] // CHECK-NEXT: %[[TG9:.*]] = vector.tuple_get %[[ES1]], 1 : tuple, vector<2x2xf32>> // CHECK-NEXT: %[[TG10:.*]] = vector.tuple_get %[[ES3]], 2 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG11:.*]] = vector.tuple_get %[[ES4]], 1 : tuple, vector<2x2xi1>> -// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[TG9]], %[[TG2]], %[[TG10]], %[[TG11]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = "add"} %[[TG9]], %[[TG2]], %[[TG10]], 
%[[TG11]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // Reducing output vector [2, 2] // CHECK-NEXT: %[[TG12:.*]] = vector.tuple_get %[[ES3]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> -// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[TG9]], %[[TG6]], %[[TG12]], %[[TG11]], %[[TG8]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = "add"} %[[TG9]], %[[TG6]], %[[TG12]], %[[TG11]], %[[TG8]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[RES0:.*]] = vector.tuple %[[R1S00]], %[[R1S02]], %[[R1S20]], %[[R1S22]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32> // CHECK-NEXT: %[[RES1:.*]] = vector.insert_slices %[[RES0]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32> @@ -242,10 +242,10 @@ // CHECK-NEXT: %[[VTR6:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> // CHECK-NEXT: %[[VTR7:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> -// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> -// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> -// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into 
vector<2x2xf32> -// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = "add"} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = "add"} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = "add"} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = "add"} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] {masked = [false, false]} : vector<2x2xf32>, memref<4x4xf32> // CHECK-NEXT: vector.transfer_write %[[R1]], %{{.*}}[%[[C0]], %[[C2]]] {masked = [false, false]} : vector<2x2xf32>, memref<4x4xf32>