diff --git a/mlir/include/mlir/Dialect/Vector/CMakeLists.txt b/mlir/include/mlir/Dialect/Vector/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Vector/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Vector/CMakeLists.txt @@ -1,2 +1,4 @@ add_mlir_dialect(VectorOps vector) +mlir_tablegen(VectorOpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(VectorOpsEnums.cpp.inc -gen-enum-defs) add_mlir_doc(VectorOps -gen-op-doc VectorOps Dialects/) diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h @@ -22,6 +22,9 @@ #include "mlir/Interfaces/VectorInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" +// Pull in all enum type definitions and utility function declarations. +#include "mlir/Dialect/Vector/VectorOpsEnums.h.inc" + namespace mlir { class MLIRContext; class OwningRewritePatternList; diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -37,6 +37,24 @@ let parser = [{ return ::parse$cppClass(parser, result); }]; } +// The "kind" of combining function for contractions and reductions. +def COMBINING_KIND_ADD : StrEnumAttrCase<"ADD", 1, "add">; +def COMBINING_KIND_MUL : StrEnumAttrCase<"MUL", 2, "mul">; +def COMBINING_KIND_MIN : StrEnumAttrCase<"MIN", 3, "min">; +def COMBINING_KIND_MAX : StrEnumAttrCase<"MAX", 4, "max">; +def COMBINING_KIND_AND : StrEnumAttrCase<"AND", 5, "and">; +def COMBINING_KIND_OR : StrEnumAttrCase<"OR", 6, "or">; +def COMBINING_KIND_XOR : StrEnumAttrCase<"XOR", 7, "xor">; + +def CombiningKindAttr : StrEnumAttr< + "CombiningKind", + "Kind of combining function for contractions and reductions", + [COMBINING_KIND_ADD, COMBINING_KIND_MUL, COMBINING_KIND_MIN, + COMBINING_KIND_MAX, COMBINING_KIND_AND, COMBINING_KIND_OR, + COMBINING_KIND_XOR]> { + let cppNamespace = "::mlir"; +} + // TODO: Add an attribute to specify a different algebra with operators other // than the current set: {*, +}. def Vector_ContractionOp : @@ -47,7 +65,8 @@ TCresVTEtIsSameAsOpBase<0, 2>>, DeclareOpInterfaceMethods ]>, - Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc, + Arguments<(ins DefaultValuedAttr:$kind, + AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc, Variadic>:$masks, AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types)>, Results<(outs AnyType)> { @@ -88,6 +107,11 @@ and acc arguments. An indexing map attribute specifies a mapping from each iterator in the iterator type list, to each dimension of an N-D vector. + An optional kind attribute may be used to specify the combining function + between the intermediate result and accumulator argument of rank K. This + attribute can take the values add/mul/min/max for int/fp, and/or/xor for + int only. The default is "add". + Example: ```mlir @@ -146,6 +170,20 @@ // types than accumulator/result. %6 = vector.contract #contraction_trait %0, %1, %2 : vector<10xf16>, vector<10xf16> into f32 + + // Contract with max (K = 0). + #contraction_accesses = [ + affine_map<(i) -> (i)>, + affine_map<(i) -> (i)>, + affine_map<(i) -> ()> + ] + #contraction_trait = { + indexing_maps = #contraction_accesses, + iterator_types = ["reduction"], + kind = "max" + } + %7 = vector.contract #contraction_trait %0, %1, %2 + : vector<10xf32>, vector<10xf32> into f32 ``` }]; let builders = [ @@ -189,6 +227,17 @@ std::vector> getContractingDimMap(); std::vector> getBatchDimMap(); + + static constexpr StringRef getKindAttrName() { return "kind"; } + + CombiningKind getCombiningKind() { + return symbolizeEnum((*this)->getAttrOfType( + getKindAttrName()).getValue()).getValue(); + } + + static StringRef getDefaultKind() { + return stringifyCombiningKind(CombiningKind::ADD); + } }]; } @@ -820,7 +869,9 @@ TCresVTEtIsSameAsOpBase<0, 0>>, PredOpTrait<"rhs operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 1>>]>, - Arguments<(ins AnyVector:$lhs, AnyType:$rhs, Variadic:$acc)>, + Arguments<(ins AnyVector:$lhs, AnyType:$rhs, + DefaultValuedAttr:$kind, + Variadic:$acc)>, Results<(outs AnyVector)> { let summary = "vector outerproduct with optional fused add"; let description = [{ @@ -846,6 +897,12 @@ lowered to the LLVMIR dialect, this form emits `llvm.intr.fma`, which is guaranteed to lower to actual `fma` instructions on x86. + An optional kind attribute may be specified to be add/mul/min/max + for int/fp, and and/or/xor for int only. The default is "add", in which + case the operation returns a fused multiply-add. In other cases it returns + a multiply followed by the appropriate operation (for example, a compare and + select for "max"). + Example: ``` @@ -856,6 +913,10 @@ vector<4xf32>, vector<8xf32>, vector<4x8xf32> return %3: vector<4x8xf32> + %4 = vector.outerproduct %0, %1, %2 {kind = "max"}: + vector<4xf32>, vector<8xf32>, vector<4x8xf32> + return %3: vector<4x8xf32> + %6 = vector.outerproduct %4, %5: vector<10xf32>, f32 return %6: vector<10xf32> @@ -880,6 +941,16 @@ VectorType getVectorType() { return getResult().getType().cast(); } + static constexpr StringRef getKindAttrName() { + return "kind"; + } + CombiningKind getCombiningKind() { + return symbolizeEnum((*this)->getAttrOfType( + getKindAttrName()).getValue()).getValue(); + } + static StringRef getDefaultKind() { + return stringifyCombiningKind(CombiningKind::ADD); + } }]; } diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1102,11 +1102,11 @@ } // An enum attribute case stored with StringAttr. -class StrEnumAttrCase : - EnumAttrCaseInfo, +class StrEnumAttrCase : + EnumAttrCaseInfo, StringBasedAttr< - CPred<"$_self.cast<::mlir::StringAttr>().getValue() == \"" # sym # "\"">, - "case " # sym>; + CPred<"$_self.cast<::mlir::StringAttr>().getValue() == \"" # str # "\"">, + "case " # str>; // An enum attribute case stored with IntegerAttr, which has an integer value, // its representation as a string and a C++ symbol name which may be different. diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -27,6 +27,9 @@ #include "llvm/ADT/StringSet.h" #include +// Pull in all enum type and utility function definitions. +#include "mlir/Dialect/Vector/VectorOpsEnums.cpp.inc" + using namespace mlir; using namespace mlir::vector; @@ -182,6 +185,8 @@ AffineMap::inferFromExprList(indexingExprs))); result.addAttribute(getIteratorTypesAttrName(), builder.getStrArrayAttr(iteratorTypes)); + result.addAttribute(ContractionOp::getKindAttrName(), + builder.getStringAttr(ContractionOp::getDefaultKind())); } void vector::ContractionOp::build(OpBuilder &builder, OperationState &result, @@ -192,6 +197,8 @@ result.addTypes(acc.getType()); result.addAttribute(getIndexingMapsAttrName(), indexingMaps); result.addAttribute(getIteratorTypesAttrName(), iteratorTypes); + result.addAttribute(ContractionOp::getKindAttrName(), + builder.getStringAttr(ContractionOp::getDefaultKind())); } static ParseResult parseContractionOp(OpAsmParser &parser, @@ -220,6 +227,11 @@ return failure(); result.attributes.assign(dictAttr.getValue().begin(), dictAttr.getValue().end()); + if (!result.attributes.get(ContractionOp::getKindAttrName())) { + result.addAttribute( + ContractionOp::getKindAttrName(), + StringAttr::get(ContractionOp::getDefaultKind(), result.getContext())); + } if (masksInfo.empty()) return success(); if (masksInfo.size() != 2) @@ -420,12 +432,35 @@ rhsMaskType.getShape().size() != rhsType.getShape().size()) return op.emitOpError("invalid vector mask rank"); } + + // Verify supported combining kind + CombiningKind combiningKind = op.getCombiningKind(); + auto vectorType = resType.dyn_cast(); + auto elementType = vectorType ? vectorType.getElementType() : resType; + switch (combiningKind) { + case CombiningKind::ADD: + case CombiningKind::MUL: + case CombiningKind::MIN: + case CombiningKind::MAX: + if (!elementType.isIntOrIndexOrFloat()) { + return op.emitOpError("unsupported contraction type"); + } + break; + case CombiningKind::AND: + case CombiningKind::OR: + case CombiningKind::XOR: + if (!elementType.isIntOrIndex()) { + return op.emitOpError("unsupported contraction type"); + } + } + return success(); } ArrayRef ContractionOp::getTraitAttrNames() { - static constexpr StringRef names[2] = {getIndexingMapsAttrName(), - getIteratorTypesAttrName()}; + static constexpr StringRef names[3] = {getIndexingMapsAttrName(), + getIteratorTypesAttrName(), + ContractionOp::getKindAttrName()}; return llvm::makeArrayRef(names); } @@ -1492,12 +1527,17 @@ Value lhs, Value rhs, Value acc) { result.addOperands({lhs, rhs, acc}); result.addTypes(acc.getType()); + result.addAttribute( + OuterProductOp::getKindAttrName(), + StringAttr::get(OuterProductOp::getDefaultKind(), result.getContext())); } static void print(OpAsmPrinter &p, OuterProductOp op) { p << op.getOperationName() << " " << op.lhs() << ", " << op.rhs(); - if (!op.acc().empty()) + if (!op.acc().empty()) { p << ", " << op.acc(); + p.printOptionalAttrDict(op.getAttrs()); + } p << " : " << op.lhs().getType() << ", " << op.rhs().getType(); } @@ -1505,8 +1545,10 @@ OperationState &result) { SmallVector operandsInfo; Type tLHS, tRHS; - if (parser.parseOperandList(operandsInfo) || parser.parseColonType(tLHS) || - parser.parseComma() || parser.parseType(tRHS)) + if (parser.parseOperandList(operandsInfo) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(tLHS) || parser.parseComma() || + parser.parseType(tRHS)) return failure(); if (operandsInfo.size() < 2) return parser.emitError(parser.getNameLoc(), @@ -1520,6 +1562,13 @@ vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)}, vLHS.getElementType()) : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType()); + + if (!result.attributes.get(OuterProductOp::getKindAttrName())) { + result.attributes.append( + OuterProductOp::getKindAttrName(), + StringAttr::get(OuterProductOp::getDefaultKind(), result.getContext())); + } + return failure( parser.resolveOperand(operandsInfo[0], tLHS, result.operands) || parser.resolveOperand(operandsInfo[1], tRHS, result.operands) || @@ -1557,6 +1606,27 @@ if (vACC && vACC != vRES) return op.emitOpError("expected operand #3 of same type as result type"); + + // Verify supported combining kind + CombiningKind combiningKind = op.getCombiningKind(); + auto elementType = vRES.getElementType(); + switch (combiningKind) { + case CombiningKind::ADD: + case CombiningKind::MUL: + case CombiningKind::MIN: + case CombiningKind::MAX: + if (!elementType.isIntOrIndexOrFloat()) { + return op.emitOpError("unsupported outerproduct type"); + } + break; + case CombiningKind::AND: + case CombiningKind::OR: + case CombiningKind::XOR: + if (!elementType.isIntOrIndex()) { + return op.emitOpError("unsupported outerproduct type"); + } + } + return success(); } diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -1355,11 +1355,18 @@ Type eltType = resType.getElementType(); bool isInt = eltType.isa(); Value acc = (op.acc().empty()) ? nullptr : op.acc()[0]; + CombiningKind kind = op.getCombiningKind(); if (!rhsType) { // Special case: AXPY operation. Value b = rewriter.create(loc, lhsType, op.rhs()); - rewriter.replaceOp(op, genMult(loc, op.lhs(), b, acc, isInt, rewriter)); + Optional mult = + isInt ? genMultI(loc, op.lhs(), b, acc, kind, rewriter) + : genMultF(loc, op.lhs(), b, acc, kind, rewriter); + if (!mult.hasValue()) { + return failure(); + } + rewriter.replaceOp(op, mult.getValue()); return success(); } @@ -1372,25 +1379,93 @@ Value r = nullptr; if (acc) r = rewriter.create(loc, rhsType, acc, pos); - Value m = genMult(loc, a, op.rhs(), r, isInt, rewriter); - result = rewriter.create(loc, resType, m, result, pos); + Optional m = isInt ? genMultI(loc, a, op.rhs(), r, kind, rewriter) + : genMultF(loc, a, op.rhs(), r, kind, rewriter); + if (!m.hasValue()) { + return failure(); + } + result = rewriter.create(loc, resType, m.getValue(), + result, pos); } rewriter.replaceOp(op, result); return success(); } private: - static Value genMult(Location loc, Value x, Value y, Value acc, bool isInt, - PatternRewriter &rewriter) { - if (acc) { - if (isInt) - return rewriter.create(loc, rewriter.create(loc, x, y), - acc); - return rewriter.create(loc, x, y, acc); + static Optional genMultI(Location loc, Value x, Value y, Value acc, + CombiningKind kind, + PatternRewriter &rewriter) { + MulIOp mul = rewriter.create(loc, x, y); + + if (!acc) + return Optional(mul); + + Value combinedResult; + switch (kind) { + case CombiningKind::ADD: + combinedResult = rewriter.create(loc, mul, acc); + break; + case CombiningKind::MUL: + combinedResult = rewriter.create(loc, mul, acc); + break; + case CombiningKind::MIN: + combinedResult = rewriter.create( + loc, rewriter.create(loc, CmpIPredicate::slt, mul, acc), mul, + acc); + break; + case CombiningKind::MAX: + combinedResult = rewriter.create( + loc, rewriter.create(loc, CmpIPredicate::sge, mul, acc), mul, + acc); + break; + case CombiningKind::AND: + combinedResult = rewriter.create(loc, mul, acc); + break; + case CombiningKind::OR: + combinedResult = rewriter.create(loc, mul, acc); + break; + case CombiningKind::XOR: + combinedResult = rewriter.create(loc, mul, acc); + break; + } + return Optional(combinedResult); + } + + static Optional genMultF(Location loc, Value x, Value y, Value acc, + CombiningKind kind, + PatternRewriter &rewriter) { + // Special case for fused multiply-add. + if (acc && kind == CombiningKind::ADD) { + return Optional(rewriter.create(loc, x, y, acc)); + } + + MulFOp mul = rewriter.create(loc, x, y); + + if (!acc) + return Optional(mul); + + Value combinedResult; + switch (kind) { + case CombiningKind::MUL: + combinedResult = rewriter.create(loc, mul, acc); + break; + case CombiningKind::MIN: + combinedResult = rewriter.create( + loc, rewriter.create(loc, CmpFPredicate::OLE, mul, acc), mul, + acc); + break; + case CombiningKind::MAX: + combinedResult = rewriter.create( + loc, rewriter.create(loc, CmpFPredicate::OGT, mul, acc), mul, + acc); + break; + case CombiningKind::ADD: // Already handled this special case above. + case CombiningKind::AND: // Only valid for integer types + case CombiningKind::OR: // Only valid for integer types + case CombiningKind::XOR: // Only valid for integer types + return Optional(); } - if (isInt) - return rewriter.create(loc, x, y); - return rewriter.create(loc, x, y); + return Optional(combinedResult); } }; @@ -1805,7 +1880,8 @@ for (int64_t k = 0; k < reductionSize; ++k) { Value a = rewriter.create(op.getLoc(), lhs, k); Value b = rewriter.create(op.getLoc(), rhs, k); - res = rewriter.create(op.getLoc(), a, b, res); + res = rewriter.create(op.getLoc(), res.getType(), a, + b, op.kind(), res); } rewriter.replaceOp(op, res); return success(); diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -67,7 +67,7 @@ // CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32> // CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<16x32xf32> // CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32> -// CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32> +// CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"], kind = "add"} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32> // CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32> func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>, @@ -86,7 +86,7 @@ // CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x16xi32> // CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<16x32xi32> // CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32> -// CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xi32>, vector<16x32xi32> into vector<8x32xi32> +// CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"], kind = "add"} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xi32>, vector<16x32xi32> into vector<8x32xi32> // CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32> func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>, @@ -311,6 +311,6 @@ // CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x4xf32> // CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<4x12xf32> // CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32> -// CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"]} %[[V0]], %[[V1]], %[[V2]] : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32> +// CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"], kind = "add"} %[[V0]], %[[V1]], %[[V2]] : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32> // CHECK: %[[W:.*]] = vector.transfer_write %[[C]], %[[ARG2]][%[[C0]], %[[C0]]] {masked = [false, false]} : vector<8x12xf32>, tensor<8x12xf32> // CHECK: return %[[W]] : tensor<8x12xf32> diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -198,13 +198,34 @@ func @contraction_to_scalar(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32 { // CHECK: %[[C0:.*]] = constant 0.000000e+00 : f32 %f0 = constant 0.0: f32 - // CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"]} %{{.*}}, %{{.*}}, %[[C0]] : vector<10xf32>, vector<10xf32> into f32 + // CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"], kind = "add"} %{{.*}}, %{{.*}}, %[[C0]] : vector<10xf32>, vector<10xf32> into f32 %0 = vector.contract #contraction_to_scalar_trait %arg0, %arg1, %f0 : vector<10xf32>, vector<10xf32> into f32 // CHECK: return %[[X]] : f32 return %0 : f32 } +#contraction_to_scalar_max_accesses = [ + affine_map<(i) -> (i)>, + affine_map<(i) -> (i)>, + affine_map<(i) -> ()> +] +#contraction_to_scalar_max_trait = { + indexing_maps = #contraction_to_scalar_max_accesses, + iterator_types = ["reduction"], + kind = "max" +} +// CHECK-LABEL: @contraction_to_scalar_with_max +func @contraction_to_scalar_with_max(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32 { + // CHECK: %[[C0:.*]] = constant 0.000000e+00 : f32 + %f0 = constant 0.0: f32 + // CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"], kind = "max"} %{{.*}}, %{{.*}}, %[[C0]] : vector<10xf32>, vector<10xf32> into f32 + %0 = vector.contract #contraction_to_scalar_max_trait %arg0, %arg1, %f0 + : vector<10xf32>, vector<10xf32> into f32 + // CHECK: return %[[X]] : f32 + return %0 : f32 +} + #contraction_accesses0 = [ affine_map<(b0, f0, f1, c0, c1) -> (c0, b0, c1, f0)>, affine_map<(b0, f0, f1, c0, c1) -> (b0, c1, c0, f1)>, @@ -221,36 +242,46 @@ // 8, 8, 15, 5 affine_map<(f0, f1, f2, f3, c0, c1) -> (f0, f1, f2, f3)> ] +#iterator_types1 = ["parallel", "parallel", "parallel", "parallel", "reduction", + "reduction"] #contraction_trait1 = { indexing_maps = #contraction_accesses1, - iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", - "reduction"] + iterator_types = #iterator_types1 +} +#contraction_trait2 = { + indexing_maps = #contraction_accesses1, + iterator_types = #iterator_types1, + kind = "max" } // CHECK-LABEL: @contraction func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>, %arg2 : vector<8x15x5xf32>, %arg3 : vector<8x8x15x5xf32>, %arg4 : vector<7x8x16x15xf16>, %arg5 : vector<8x16x7x5xf16>) { // Test contraction with batch and contracting dims. - // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32> + // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"], kind = "add"} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32> %0 = vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32> // Test contraction with only contracting dims. In this case the lhs/rhs // dimension of size 8 will be considered a parallel dim for lhs/rhs and will // appear twice in the output. - // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> + // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = "add"} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> %1 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3 : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> // Test contraction with optional vector mask arguments. %lhs_mask = vector.constant_mask [7, 8, 16, 15] : vector<7x8x16x15xi1> %rhs_mask = vector.constant_mask [8, 16, 7, 5] : vector<8x16x7x5xi1> - // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> + // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = "add"} {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> %2 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3, %lhs_mask, %rhs_mask : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> // Test contraction with mixed type. - // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32> + // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = "add"} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32> %3 = vector.contract #contraction_trait1 %arg4, %arg5, %arg3 : vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32> + // Test contraction with "max" instead of "add". + // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = "max"} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> + %4 = vector.contract #contraction_trait2 %arg0, %arg1, %arg3 + : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> return } diff --git a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir @@ -50,10 +50,10 @@ // CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32> // CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2x2xf32> // CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : vector<2xf32> -// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] : vector<2xf32>, f32 +// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = "add"} : vector<2xf32>, f32 // CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2x2xf32> // CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : vector<2xf32> -// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] : vector<2xf32>, f32 +// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = "add"} : vector<2xf32>, f32 // CHECK: store %[[T9]], %[[C]][] : memref> // CHECK: return func @matvec2x2(%arg0: memref>, %arg1: memref>, @@ -75,10 +75,10 @@ // CHECK: %[[T2:.*]] = load %[[C]][] : memref> // CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2x2xf32> // CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : vector<2xf32> -// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] : vector<2xf32>, f32 +// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] {kind = "add"} : vector<2xf32>, f32 // CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2x2xf32> // CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : vector<2xf32> -// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] : vector<2xf32>, f32 +// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = "add"} : vector<2xf32>, f32 // CHECK: store %[[T8]], %[[C]][] : memref> // CHECK: return func @mattransvec2x2(%arg0: memref>, %arg1: memref>, @@ -101,10 +101,10 @@ // CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32> // CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2x2xf32> // CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : vector<2xf32> -// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] : vector<2xf32>, f32 +// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = "add"} : vector<2xf32>, f32 // CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2x2xf32> // CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : vector<2xf32> -// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] : vector<2xf32>, f32 +// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = "add"} : vector<2xf32>, f32 // CHECK: store %[[T9]], %[[C]][] : memref> // CHECK: return func @vecmat2x2(%arg0: memref>, %arg1: memref>, @@ -126,10 +126,10 @@ // CHECK: %[[T2:.*]] = load %[[C]][] : memref> // CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2x2xf32> // CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : vector<2xf32> -// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] : vector<2xf32>, f32 +// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] {kind = "add"} : vector<2xf32>, f32 // CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2x2xf32> // CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : vector<2xf32> -// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] : vector<2xf32>, f32 +// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = "add"} : vector<2xf32>, f32 // CHECK: store %[[T8]], %[[C]][] : memref> // CHECK: return func @vecmattrans2x2(%arg0: memref>, %arg1: memref>, diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-transforms.mlir @@ -92,56 +92,56 @@ // CHECK-NEXT: %[[TG3:.*]] = vector.tuple_get %[[ES3]], 0 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG4:.*]] = vector.tuple_get %[[ES4]], 0 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> // CHECK-NEXT: %[[TG5:.*]] = vector.tuple_get %[[ES5]], 0 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> -// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG1]], %[[TG2]], %[[TG3]], %[[TG4]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = "add"} %[[TG1]], %[[TG2]], %[[TG3]], %[[TG4]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[TG6:.*]] = vector.tuple_get %[[ES1]], 1 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG7:.*]] = vector.tuple_get %[[ES2]], 2 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG8:.*]] = vector.tuple_get %[[ES4]], 1 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> // CHECK-NEXT: %[[TG9:.*]] = vector.tuple_get %[[ES5]], 2 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> -// CHECK-NEXT: %[[R2S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG6]], %[[TG7]], %[[R1S00]], %[[TG8]], %[[TG9]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R2S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = "add"} %[[TG6]], %[[TG7]], %[[R1S00]], %[[TG8]], %[[TG9]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[TG10:.*]] = vector.tuple_get %[[ES1]], 2 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG11:.*]] = vector.tuple_get %[[ES2]], 4 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG12:.*]] = vector.tuple_get %[[ES4]], 2 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> // CHECK-NEXT: %[[TG13:.*]] = vector.tuple_get %[[ES5]], 4 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> -// CHECK-NEXT: %[[R3S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG10]], %[[TG11]], %[[R2S00]], %[[TG12]], %[[TG13]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R3S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = "add"} %[[TG10]], %[[TG11]], %[[R2S00]], %[[TG12]], %[[TG13]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // Reducing output vector [0, 2] // CHECK-NEXT: %[[TG14:.*]] = vector.tuple_get %[[ES2]], 1 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG15:.*]] = vector.tuple_get %[[ES3]], 1 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG16:.*]] = vector.tuple_get %[[ES5]], 1 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> -// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG1]], %[[TG14]], %[[TG15]], %[[TG4]], %[[TG16]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = "add"} %[[TG1]], %[[TG14]], %[[TG15]], %[[TG4]], %[[TG16]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[TG17:.*]] = vector.tuple_get %[[ES2]], 3 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG18:.*]] = vector.tuple_get %[[ES5]], 3 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> -// CHECK-NEXT: %[[R2S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG6]], %[[TG17]], %[[R1S02]], %[[TG8]], %[[TG18]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R2S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = "add"} %[[TG6]], %[[TG17]], %[[R1S02]], %[[TG8]], %[[TG18]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[TG19:.*]] = vector.tuple_get %[[ES2]], 5 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG20:.*]] = vector.tuple_get %[[ES5]], 5 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> -// CHECK-NEXT: %[[R3S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG10]], %[[TG19]], %[[R2S02]], %[[TG12]], %[[TG20]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R3S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = "add"} %[[TG10]], %[[TG19]], %[[R2S02]], %[[TG12]], %[[TG20]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // Reducing output vector [2, 0] // CHECK-NEXT: %[[TG21:.*]] = vector.tuple_get %[[ES1]], 3 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG22:.*]] = vector.tuple_get %[[ES3]], 2 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG23:.*]] = vector.tuple_get %[[ES4]], 3 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> -// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG21]], %[[TG2]], %[[TG22]], %[[TG23]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = "add"} %[[TG21]], %[[TG2]], %[[TG22]], %[[TG23]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[TG24:.*]] = vector.tuple_get %[[ES1]], 4 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG25:.*]] = vector.tuple_get %[[ES4]], 4 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> -// CHECK-NEXT: %[[R2S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG24]], %[[TG7]], %[[R1S20]], %[[TG25]], %[[TG9]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R2S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = "add"} %[[TG24]], %[[TG7]], %[[R1S20]], %[[TG25]], %[[TG9]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[TG26:.*]] = vector.tuple_get %[[ES1]], 5 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG27:.*]] = vector.tuple_get %[[ES4]], 5 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> -// CHECK-NEXT: %[[R3S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG26]], %[[TG11]], %[[R2S20]], %[[TG27]], %[[TG13]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R3S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = "add"} %[[TG26]], %[[TG11]], %[[R2S20]], %[[TG27]], %[[TG13]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // Reducing output vector [2, 2] // CHECK-NEXT: %[[TG28:.*]] = vector.tuple_get %[[ES3]], 3 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> -// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG21]], %[[TG14]], %[[TG28]], %[[TG23]], %[[TG16]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> -// CHECK-NEXT: %[[R2S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG24]], %[[TG17]], %[[R1S22]], %[[TG25]], %[[TG18]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> -// CHECK-NEXT: %[[R3S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG26]], %[[TG19]], %[[R2S22]], %[[TG27]], %[[TG20]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = "add"} %[[TG21]], %[[TG14]], %[[TG28]], %[[TG23]], %[[TG16]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R2S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = "add"} %[[TG24]], %[[TG17]], %[[R1S22]], %[[TG25]], %[[TG18]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R3S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = "add"} %[[TG26]], %[[TG19]], %[[R2S22]], %[[TG27]], %[[TG20]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[RES0:.*]] = vector.tuple %[[R3S00]], %[[R3S02]], %[[R3S20]], %[[R3S22]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32> // CHECK-NEXT: %[[RES1:.*]] = vector.insert_slices %[[RES0]], [2, 2], [1, 1] : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32> @@ -187,26 +187,26 @@ // CHECK-NEXT: %[[TG3:.*]] = vector.tuple_get %[[ES3]], 0 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG4:.*]] = vector.tuple_get %[[ES4]], 0 : tuple, vector<2x2xi1>> // CHECK-NEXT: %[[TG5:.*]] = vector.tuple_get %[[ES5]], 0 : tuple, vector<2x2xi1>> -// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[TG1]], %[[TG2]], %[[TG3]], %[[TG4]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = "add"} %[[TG1]], %[[TG2]], %[[TG3]], %[[TG4]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // Reducing output vector [0, 2] // CHECK-NEXT: %[[TG6:.*]] = vector.tuple_get %[[ES2]], 1 : tuple, vector<2x2xf32>> // CHECK-NEXT: %[[TG7:.*]] = vector.tuple_get %[[ES3]], 1 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG8:.*]] = vector.tuple_get %[[ES5]], 1 : tuple, vector<2x2xi1>> -// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[TG1]], %[[TG6]], %[[TG7]], %[[TG4]], %[[TG8]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = "add"} %[[TG1]], %[[TG6]], %[[TG7]], %[[TG4]], %[[TG8]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // Reducing output vector [2, 0] // CHECK-NEXT: %[[TG9:.*]] = vector.tuple_get %[[ES1]], 1 : tuple, vector<2x2xf32>> // CHECK-NEXT: %[[TG10:.*]] = vector.tuple_get %[[ES3]], 2 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG11:.*]] = vector.tuple_get %[[ES4]], 1 : tuple, vector<2x2xi1>> -// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[TG9]], %[[TG2]], %[[TG10]], %[[TG11]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = "add"} %[[TG9]], %[[TG2]], %[[TG10]], %[[TG11]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // Reducing output vector [2, 2] // CHECK-NEXT: %[[TG12:.*]] = vector.tuple_get %[[ES3]], 3 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> -// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[TG9]], %[[TG6]], %[[TG12]], %[[TG11]], %[[TG8]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = "add"} %[[TG9]], %[[TG6]], %[[TG12]], %[[TG11]], %[[TG8]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[RES0:.*]] = vector.tuple %[[R1S00]], %[[R1S02]], %[[R1S20]], %[[R1S22]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32> // CHECK-NEXT: %[[RES1:.*]] = vector.insert_slices %[[RES0]], [2, 2], [1, 1] : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32> @@ -241,10 +241,10 @@ // CHECK-NEXT: %[[VTR6:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> // CHECK-NEXT: %[[VTR7:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> -// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> -// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> -// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> -// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = "add"} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = "add"} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = "add"} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = "add"} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] {masked = [false, false]} : vector<2x2xf32>, memref<4x4xf32> // CHECK-NEXT: vector.transfer_write %[[R1]], %{{.*}}[%[[C0]], %[[C2]]] {masked = [false, false]} : vector<2x2xf32>, memref<4x4xf32> @@ -572,10 +572,10 @@ // CHECK-NEXT: %[[VTR6:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32> // CHECK-NEXT: %[[VTR7:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32> -// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> -// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> -// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> -// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = "add"} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = "add"} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = "add"} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = "add"} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[VTW0:.*]] = vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] {masked = [false, false]} : vector<2x2xf32>, tensor<4x4xf32> // CHECK-NEXT: %[[VTW1:.*]] = vector.transfer_write %[[R1]], %[[VTW0]][%[[C0]], %[[C2]]] {masked = [false, false]} : vector<2x2xf32>, tensor<4x4xf32>