diff --git a/mlir/include/mlir/Dialect/Vector/CMakeLists.txt b/mlir/include/mlir/Dialect/Vector/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Vector/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Vector/CMakeLists.txt @@ -1,2 +1,4 @@ add_mlir_dialect(VectorOps vector) +mlir_tablegen(VectorOpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(VectorOpsEnums.cpp.inc -gen-enum-defs) add_mlir_doc(VectorOps -gen-op-doc VectorOps Dialects/) diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h @@ -21,11 +21,21 @@ #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/VectorInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" +#include "llvm/ADT/StringExtras.h" + +// Pull in all enum type definitions and utility function declarations. +#include "mlir/Dialect/Vector/VectorOpsEnums.h.inc" namespace mlir { class MLIRContext; class OwningRewritePatternList; + namespace vector { +class VectorDialect; + +namespace detail { +struct BitmaskEnumStorage; +} // namespace detail /// Collect a set of vector-to-vector canonicalization patterns. void populateVectorToVectorCanonicalizationPatterns( @@ -46,6 +56,21 @@ void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns, MLIRContext *context); +/// An attribute that specifies the combining function for `vector.contract`, +/// and `vector.reduction`. +class CKAttr : public Attribute::AttrBase { +public: + using Base::Base; + + static CKAttr get(CombiningKind kind, MLIRContext *context); + + CombiningKind getKind() const; + + void print(DialectAsmPrinter &p) const; + static Attribute parse(DialectAsmParser &parser); +}; + /// Enum to control the lowering of `vector.contract` operations. enum class VectorContractLowering { /// Progressively lower to finer grained `vector.contract` and dot-products. diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -37,6 +37,35 @@ let parser = [{ return ::parse$cppClass(parser, result); }]; } +// The "kind" of combining function for contractions and reductions. +def COMBINING_KIND_ADD : BitEnumAttrCase<"ADD", 0x1, "add">; +def COMBINING_KIND_MUL : BitEnumAttrCase<"MUL", 0x2, "mul">; +def COMBINING_KIND_MIN : BitEnumAttrCase<"MIN", 0x4, "min">; +def COMBINING_KIND_MAX : BitEnumAttrCase<"MAX", 0x8, "max">; +def COMBINING_KIND_AND : BitEnumAttrCase<"AND", 0x10, "and">; +def COMBINING_KIND_OR : BitEnumAttrCase<"OR", 0x20, "or">; +def COMBINING_KIND_XOR : BitEnumAttrCase<"XOR", 0x40, "xor">; + +def CombiningKind : BitEnumAttr< + "CombiningKind", + "Kind of combining function for contractions and reductions", + [COMBINING_KIND_ADD, COMBINING_KIND_MUL, COMBINING_KIND_MIN, + COMBINING_KIND_MAX, COMBINING_KIND_AND, COMBINING_KIND_OR, + COMBINING_KIND_XOR]> { + let cppNamespace = "::mlir::vector"; +} + +def Vector_CKAttr : DialectAttr< + Vector_Dialect, + CPred<"$_self.isa<::mlir::vector::CKAttr>()">, + "Kind of combining function for contractions and reductions"> { + let storageType = "::mlir::vector::CKAttr"; + let returnType = "::mlir::vector::CombiningKind"; + let convertFromStorage = "$_self.getKind()"; + let constBuilderCall = + "::mlir::vector::CKAttr::get($0, $_builder.getContext())"; +} + // TODO: Add an attribute to specify a different algebra with operators other // than the current set: {*, +}. def Vector_ContractionOp : @@ -47,7 +76,8 @@ TCresVTEtIsSameAsOpBase<0, 2>>, DeclareOpInterfaceMethods ]>, - Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc, + Arguments<(ins DefaultValuedAttr:$kind, + AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc, Variadic>:$masks, AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types)>, Results<(outs AnyType)> { @@ -88,6 +118,11 @@ and acc arguments. An indexing map attribute specifies a mapping from each iterator in the iterator type list, to each dimension of an N-D vector. + An optional kind attribute may be used to specify the combining function + between the intermediate result and accumulator argument of rank K. This + attribute can take the values add/mul/min/max for int/fp, and/or/xor for + int only. The default is "add". + Example: ```mlir @@ -146,6 +181,20 @@ // types than accumulator/result. %6 = vector.contract #contraction_trait %0, %1, %2 : vector<10xf16>, vector<10xf16> into f32 + + // Contract with max (K = 0). + #contraction_accesses = [ + affine_map<(i) -> (i)>, + affine_map<(i) -> (i)>, + affine_map<(i) -> ()> + ] + #contraction_trait = { + indexing_maps = #contraction_accesses, + iterator_types = ["reduction"], + kind = #vector.kind + } + %7 = vector.contract #contraction_trait %0, %1, %2 + : vector<10xf32>, vector<10xf32> into f32 ``` }]; let builders = [ @@ -189,6 +238,12 @@ std::vector> getContractingDimMap(); std::vector> getBatchDimMap(); + + static constexpr StringRef getKindAttrName() { return "kind"; } + + static CombiningKind getDefaultKind() { + return CombiningKind::ADD; + } }]; } @@ -820,7 +875,9 @@ TCresVTEtIsSameAsOpBase<0, 0>>, PredOpTrait<"rhs operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 1>>]>, - Arguments<(ins AnyVector:$lhs, AnyType:$rhs, Variadic:$acc)>, + Arguments<(ins AnyVector:$lhs, AnyType:$rhs, + DefaultValuedAttr:$kind, + Variadic:$acc)>, Results<(outs AnyVector)> { let summary = "vector outerproduct with optional fused add"; let description = [{ @@ -846,6 +903,12 @@ lowered to the LLVMIR dialect, this form emits `llvm.intr.fma`, which is guaranteed to lower to actual `fma` instructions on x86. + An optional kind attribute may be specified to be add/mul/min/max + for int/fp, and and/or/xor for int only. The default is "add", in which + case the operation returns a fused multiply-add. In other cases it returns + a multiply followed by the appropriate operation (for example, a compare and + select for "max"). + Example: ``` @@ -856,6 +919,10 @@ vector<4xf32>, vector<8xf32>, vector<4x8xf32> return %3: vector<4x8xf32> + %4 = vector.outerproduct %0, %1, %2 {kind = #vector.kind}: + vector<4xf32>, vector<8xf32>, vector<4x8xf32> + return %3: vector<4x8xf32> + %6 = vector.outerproduct %4, %5: vector<10xf32>, f32 return %6: vector<10xf32> @@ -880,6 +947,12 @@ VectorType getVectorType() { return getResult().getType().cast(); } + static constexpr StringRef getKindAttrName() { + return "kind"; + } + static CombiningKind getDefaultKind() { + return CombiningKind::ADD; + } }]; } diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1127,9 +1127,9 @@ // A bit enum case stored with 32-bit IntegerAttr. `val` here is *not* the // ordinal number of the bit that is set. It is the 32-bit integer with only // one bit set. -class BitEnumAttrCase : - EnumAttrCaseInfo, - SignlessIntegerAttrBase { +class BitEnumAttrCase : + EnumAttrCaseInfo, + SignlessIntegerAttrBase { let predicate = CPred< "$_self.cast<::mlir::IntegerAttr>().getValue().getZExtValue() & " # val # "u">; diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -19,6 +19,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/DialectImplementation.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" @@ -27,6 +28,9 @@ #include "llvm/ADT/StringSet.h" #include +// Pull in all enum type and utility function definitions. +#include "mlir/Dialect/Vector/VectorOpsEnums.cpp.inc" + using namespace mlir; using namespace mlir::vector; @@ -76,11 +80,30 @@ return MaskFormat::Unknown; } +// Helper for verifying combining kinds in contractions and reductions. +static bool isSupportedCombiningKind(CombiningKind combiningKind, + Type elementType) { + switch (combiningKind) { + case CombiningKind::ADD: + case CombiningKind::MUL: + case CombiningKind::MIN: + case CombiningKind::MAX: + return elementType.isIntOrIndexOrFloat(); + case CombiningKind::AND: + case CombiningKind::OR: + case CombiningKind::XOR: + return elementType.isIntOrIndex(); + } + return false; +} + //===----------------------------------------------------------------------===// // VectorDialect //===----------------------------------------------------------------------===// void VectorDialect::initialize() { + addAttributes(); + addOperations< #define GET_OP_LIST #include "mlir/Dialect/Vector/VectorOps.cpp.inc" @@ -104,6 +127,104 @@ return builder.getI64ArrayAttr(values); } +//===----------------------------------------------------------------------===// +// CKAttr +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace vector { +namespace detail { +struct BitmaskEnumStorage : public AttributeStorage { + using KeyTy = uint64_t; + + BitmaskEnumStorage(KeyTy val) : value(val) {} + + bool operator==(const KeyTy &key) const { return value == key; } + + static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator, + const KeyTy &key) { + return new (allocator.allocate()) + BitmaskEnumStorage(key); + } + + KeyTy value = 0; +}; +} // namespace detail +} // namespace vector +} // namespace mlir + +CKAttr CKAttr::get(CombiningKind kind, MLIRContext *context) { + return Base::get(context, static_cast(kind)); +} + +CombiningKind CKAttr::getKind() const { + return static_cast(getImpl()->value); +} + +static constexpr const CombiningKind CombiningKindsList[] = { + // clang-format off + CombiningKind::ADD, + CombiningKind::MUL, + CombiningKind::MIN, + CombiningKind::MAX, + CombiningKind::AND, + CombiningKind::OR, + CombiningKind::XOR, + // clang-format on +}; + +void CKAttr::print(DialectAsmPrinter &printer) const { + printer << "kind<"; + auto kinds = llvm::make_filter_range(CombiningKindsList, [&](auto kind) { + return bitEnumContains(this->getKind(), kind); + }); + llvm::interleaveComma(kinds, printer, + [&](auto kind) { printer << stringifyEnum(kind); }); + printer << ">"; +} + +Attribute CKAttr::parse(DialectAsmParser &parser) { + if (failed(parser.parseLess())) + return {}; + + StringRef elemName; + if (failed(parser.parseKeyword(&elemName))) + return {}; + + auto kind = symbolizeCombiningKind(elemName); + if (!kind) { + parser.emitError(parser.getNameLoc(), "Unknown combining kind: ") + << elemName; + return {}; + } + + if (failed(parser.parseGreater())) + return {}; + + return CKAttr::get(kind.getValue(), parser.getBuilder().getContext()); +} + +Attribute VectorDialect::parseAttribute(DialectAsmParser &parser, + Type type) const { + StringRef attrKind; + if (parser.parseKeyword(&attrKind)) + return {}; + + if (attrKind == "kind") + return CKAttr::parse(parser); + + parser.emitError(parser.getNameLoc(), "Unknown attribute type: ") << attrKind; + return {}; +} + +void VectorDialect::printAttribute(Attribute attr, + DialectAsmPrinter &os) const { + if (auto ck = attr.dyn_cast()) + ck.print(os); + else + llvm_unreachable("Unknown attribute type"); +} + //===----------------------------------------------------------------------===// // ReductionOp //===----------------------------------------------------------------------===// @@ -182,6 +303,9 @@ AffineMap::inferFromExprList(indexingExprs))); result.addAttribute(getIteratorTypesAttrName(), builder.getStrArrayAttr(iteratorTypes)); + result.addAttribute( + ContractionOp::getKindAttrName(), + CKAttr::get(ContractionOp::getDefaultKind(), builder.getContext())); } void vector::ContractionOp::build(OpBuilder &builder, OperationState &result, @@ -192,6 +316,9 @@ result.addTypes(acc.getType()); result.addAttribute(getIndexingMapsAttrName(), indexingMaps); result.addAttribute(getIteratorTypesAttrName(), iteratorTypes); + result.addAttribute( + ContractionOp::getKindAttrName(), + CKAttr::get(ContractionOp::getDefaultKind(), builder.getContext())); } static ParseResult parseContractionOp(OpAsmParser &parser, @@ -220,6 +347,11 @@ return failure(); result.attributes.assign(dictAttr.getValue().begin(), dictAttr.getValue().end()); + if (!result.attributes.get(ContractionOp::getKindAttrName())) { + result.addAttribute( + ContractionOp::getKindAttrName(), + CKAttr::get(ContractionOp::getDefaultKind(), result.getContext())); + } if (masksInfo.empty()) return success(); if (masksInfo.size() != 2) @@ -420,12 +552,20 @@ rhsMaskType.getShape().size() != rhsType.getShape().size()) return op.emitOpError("invalid vector mask rank"); } + + // Verify supported combining kind. + auto vectorType = resType.dyn_cast(); + auto elementType = vectorType ? vectorType.getElementType() : resType; + if (!isSupportedCombiningKind(op.kind(), elementType)) + return op.emitOpError("unsupported contraction type"); + return success(); } ArrayRef ContractionOp::getTraitAttrNames() { - static constexpr StringRef names[2] = {getIndexingMapsAttrName(), - getIteratorTypesAttrName()}; + static constexpr StringRef names[3] = {getIndexingMapsAttrName(), + getIteratorTypesAttrName(), + ContractionOp::getKindAttrName()}; return llvm::makeArrayRef(names); } @@ -1492,12 +1632,17 @@ Value lhs, Value rhs, Value acc) { result.addOperands({lhs, rhs, acc}); result.addTypes(acc.getType()); + result.addAttribute( + OuterProductOp::getKindAttrName(), + CKAttr::get(OuterProductOp::getDefaultKind(), result.getContext())); } static void print(OpAsmPrinter &p, OuterProductOp op) { p << op.getOperationName() << " " << op.lhs() << ", " << op.rhs(); - if (!op.acc().empty()) + if (!op.acc().empty()) { p << ", " << op.acc(); + p.printOptionalAttrDict(op.getAttrs()); + } p << " : " << op.lhs().getType() << ", " << op.rhs().getType(); } @@ -1505,8 +1650,10 @@ OperationState &result) { SmallVector operandsInfo; Type tLHS, tRHS; - if (parser.parseOperandList(operandsInfo) || parser.parseColonType(tLHS) || - parser.parseComma() || parser.parseType(tRHS)) + if (parser.parseOperandList(operandsInfo) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(tLHS) || parser.parseComma() || + parser.parseType(tRHS)) return failure(); if (operandsInfo.size() < 2) return parser.emitError(parser.getNameLoc(), @@ -1520,6 +1667,13 @@ vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)}, vLHS.getElementType()) : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType()); + + if (!result.attributes.get(OuterProductOp::getKindAttrName())) { + result.attributes.append( + OuterProductOp::getKindAttrName(), + CKAttr::get(OuterProductOp::getDefaultKind(), result.getContext())); + } + return failure( parser.resolveOperand(operandsInfo[0], tLHS, result.operands) || parser.resolveOperand(operandsInfo[1], tRHS, result.operands) || @@ -1557,6 +1711,11 @@ if (vACC && vACC != vRES) return op.emitOpError("expected operand #3 of same type as result type"); + + // Verify supported combining kind. + if (!isSupportedCombiningKind(op.kind(), vRES.getElementType())) + return op.emitOpError("unsupported outerproduct type"); + return success(); } diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -1355,11 +1355,17 @@ Type eltType = resType.getElementType(); bool isInt = eltType.isa(); Value acc = (op.acc().empty()) ? nullptr : op.acc()[0]; + vector::CombiningKind kind = op.kind(); if (!rhsType) { // Special case: AXPY operation. Value b = rewriter.create(loc, lhsType, op.rhs()); - rewriter.replaceOp(op, genMult(loc, op.lhs(), b, acc, isInt, rewriter)); + Optional mult = + isInt ? genMultI(loc, op.lhs(), b, acc, kind, rewriter) + : genMultF(loc, op.lhs(), b, acc, kind, rewriter); + if (!mult.hasValue()) + return failure(); + rewriter.replaceOp(op, mult.getValue()); return success(); } @@ -1372,25 +1378,95 @@ Value r = nullptr; if (acc) r = rewriter.create(loc, rhsType, acc, pos); - Value m = genMult(loc, a, op.rhs(), r, isInt, rewriter); - result = rewriter.create(loc, resType, m, result, pos); + Optional m = isInt ? genMultI(loc, a, op.rhs(), r, kind, rewriter) + : genMultF(loc, a, op.rhs(), r, kind, rewriter); + if (!m.hasValue()) + return failure(); + result = rewriter.create(loc, resType, m.getValue(), + result, pos); } rewriter.replaceOp(op, result); return success(); } private: - static Value genMult(Location loc, Value x, Value y, Value acc, bool isInt, - PatternRewriter &rewriter) { - if (acc) { - if (isInt) - return rewriter.create(loc, rewriter.create(loc, x, y), - acc); - return rewriter.create(loc, x, y, acc); + static Optional genMultI(Location loc, Value x, Value y, Value acc, + vector::CombiningKind kind, + PatternRewriter &rewriter) { + using vector::CombiningKind; + + MulIOp mul = rewriter.create(loc, x, y); + if (!acc) + return Optional(mul); + + Value combinedResult; + switch (kind) { + case CombiningKind::ADD: + combinedResult = rewriter.create(loc, mul, acc); + break; + case CombiningKind::MUL: + combinedResult = rewriter.create(loc, mul, acc); + break; + case CombiningKind::MIN: + combinedResult = rewriter.create( + loc, rewriter.create(loc, CmpIPredicate::slt, mul, acc), mul, + acc); + break; + case CombiningKind::MAX: + combinedResult = rewriter.create( + loc, rewriter.create(loc, CmpIPredicate::sge, mul, acc), mul, + acc); + break; + case CombiningKind::AND: + combinedResult = rewriter.create(loc, mul, acc); + break; + case CombiningKind::OR: + combinedResult = rewriter.create(loc, mul, acc); + break; + case CombiningKind::XOR: + combinedResult = rewriter.create(loc, mul, acc); + break; } - if (isInt) - return rewriter.create(loc, x, y); - return rewriter.create(loc, x, y); + return Optional(combinedResult); + } + + static Optional genMultF(Location loc, Value x, Value y, Value acc, + vector::CombiningKind kind, + PatternRewriter &rewriter) { + using vector::CombiningKind; + + // Special case for fused multiply-add. + if (acc && kind == CombiningKind::ADD) { + return Optional(rewriter.create(loc, x, y, acc)); + } + + MulFOp mul = rewriter.create(loc, x, y); + + if (!acc) + return Optional(mul); + + Value combinedResult; + switch (kind) { + case CombiningKind::MUL: + combinedResult = rewriter.create(loc, mul, acc); + break; + case CombiningKind::MIN: + combinedResult = rewriter.create( + loc, rewriter.create(loc, CmpFPredicate::OLE, mul, acc), mul, + acc); + break; + case CombiningKind::MAX: + combinedResult = rewriter.create( + loc, rewriter.create(loc, CmpFPredicate::OGT, mul, acc), mul, + acc); + break; + case CombiningKind::ADD: // Already handled this special case above. + case CombiningKind::AND: // Only valid for integer types. + case CombiningKind::OR: // Only valid for integer types. + case CombiningKind::XOR: // Only valid for integer types. + return Optional(); + } + return Optional(combinedResult); } }; @@ -1805,7 +1881,8 @@ for (int64_t k = 0; k < reductionSize; ++k) { Value a = rewriter.create(op.getLoc(), lhs, k); Value b = rewriter.create(op.getLoc(), rhs, k); - res = rewriter.create(op.getLoc(), a, b, res); + res = rewriter.create(op.getLoc(), res.getType(), a, + b, op.kind(), res); } rewriter.replaceOp(op, res); return success(); diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -67,7 +67,7 @@ // CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32> // CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<16x32xf32> // CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32> -// CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32> +// CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32> // CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32> func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>, @@ -86,7 +86,7 @@ // CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x16xi32> // CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<16x32xi32> // CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32> -// CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xi32>, vector<16x32xi32> into vector<8x32xi32> +// CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xi32>, vector<16x32xi32> into vector<8x32xi32> // CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32> func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>, @@ -311,6 +311,6 @@ // CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x4xf32> // CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<4x12xf32> // CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32> -// CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"]} %[[V0]], %[[V1]], %[[V2]] : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32> +// CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[V0]], %[[V1]], %[[V2]] : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32> // CHECK: %[[W:.*]] = vector.transfer_write %[[C]], %[[ARG2]][%[[C0]], %[[C0]]] {masked = [false, false]} : vector<8x12xf32>, tensor<8x12xf32> // CHECK: return %[[W]] : tensor<8x12xf32> diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -198,13 +198,34 @@ func @contraction_to_scalar(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32 { // CHECK: %[[C0:.*]] = constant 0.000000e+00 : f32 %f0 = constant 0.0: f32 - // CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"]} %{{.*}}, %{{.*}}, %[[C0]] : vector<10xf32>, vector<10xf32> into f32 + // CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"], kind = #vector.kind} %{{.*}}, %{{.*}}, %[[C0]] : vector<10xf32>, vector<10xf32> into f32 %0 = vector.contract #contraction_to_scalar_trait %arg0, %arg1, %f0 : vector<10xf32>, vector<10xf32> into f32 // CHECK: return %[[X]] : f32 return %0 : f32 } +#contraction_to_scalar_max_accesses = [ + affine_map<(i) -> (i)>, + affine_map<(i) -> (i)>, + affine_map<(i) -> ()> +] +#contraction_to_scalar_max_trait = { + indexing_maps = #contraction_to_scalar_max_accesses, + iterator_types = ["reduction"], + kind = #vector.kind +} +// CHECK-LABEL: @contraction_to_scalar_with_max +func @contraction_to_scalar_with_max(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32 { + // CHECK: %[[C0:.*]] = constant 0.000000e+00 : f32 + %f0 = constant 0.0: f32 + // CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"], kind = #vector.kind} %{{.*}}, %{{.*}}, %[[C0]] : vector<10xf32>, vector<10xf32> into f32 + %0 = vector.contract #contraction_to_scalar_max_trait %arg0, %arg1, %f0 + : vector<10xf32>, vector<10xf32> into f32 + // CHECK: return %[[X]] : f32 + return %0 : f32 +} + #contraction_accesses0 = [ affine_map<(b0, f0, f1, c0, c1) -> (c0, b0, c1, f0)>, affine_map<(b0, f0, f1, c0, c1) -> (b0, c1, c0, f1)>, @@ -221,36 +242,46 @@ // 8, 8, 15, 5 affine_map<(f0, f1, f2, f3, c0, c1) -> (f0, f1, f2, f3)> ] +#iterator_types1 = ["parallel", "parallel", "parallel", "parallel", "reduction", + "reduction"] #contraction_trait1 = { indexing_maps = #contraction_accesses1, - iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", - "reduction"] + iterator_types = #iterator_types1 +} +#contraction_trait2 = { + indexing_maps = #contraction_accesses1, + iterator_types = #iterator_types1, + kind = #vector.kind } // CHECK-LABEL: @contraction func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>, %arg2 : vector<8x15x5xf32>, %arg3 : vector<8x8x15x5xf32>, %arg4 : vector<7x8x16x15xf16>, %arg5 : vector<8x16x7x5xf16>) { // Test contraction with batch and contracting dims. - // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32> + // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32> %0 = vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32> // Test contraction with only contracting dims. In this case the lhs/rhs // dimension of size 8 will be considered a parallel dim for lhs/rhs and will // appear twice in the output. - // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> + // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> %1 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3 : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> // Test contraction with optional vector mask arguments. %lhs_mask = vector.constant_mask [7, 8, 16, 15] : vector<7x8x16x15xi1> %rhs_mask = vector.constant_mask [8, 16, 7, 5] : vector<8x16x7x5xi1> - // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> + // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind} {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> %2 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3, %lhs_mask, %rhs_mask : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> // Test contraction with mixed type. - // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32> + // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32> %3 = vector.contract #contraction_trait1 %arg4, %arg5, %arg3 : vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32> + // Test contraction with "max" instead of "add". + // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> + %4 = vector.contract #contraction_trait2 %arg0, %arg1, %arg3 + : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> return } diff --git a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir @@ -9,6 +9,11 @@ indexing_maps = #matvec_accesses, iterator_types = ["parallel", "reduction"] } +#matvecmax_trait = { + indexing_maps = #matvec_accesses, + iterator_types = ["parallel", "reduction"], + kind = #vector.kind +} #mattransvec_accesses = [ affine_map<(i, j) -> (j, i)>, @@ -50,10 +55,10 @@ // CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32> // CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2x2xf32> // CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : vector<2xf32> -// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] : vector<2xf32>, f32 +// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind} : vector<2xf32>, f32 // CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2x2xf32> // CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : vector<2xf32> -// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] : vector<2xf32>, f32 +// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind} : vector<2xf32>, f32 // CHECK: store %[[T9]], %[[C]][] : memref> // CHECK: return func @matvec2x2(%arg0: memref>, %arg1: memref>, @@ -66,6 +71,32 @@ return } +// CHECK-LABEL: func @matvecmax2x2 +// CHECK-SAME: %[[A:.*0]]: memref> +// CHECK-SAME: %[[B:.*1]]: memref> +// CHECK-SAME: %[[C:.*2]]: memref> +// CHECK: %[[T0:.*]] = load %[[A]][] : memref> +// CHECK: %[[T1:.*]] = load %[[B]][] : memref> +// CHECK: %[[T2:.*]] = load %[[C]][] : memref> +// CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32> +// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2x2xf32> +// CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : vector<2xf32> +// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind} : vector<2xf32>, f32 +// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2x2xf32> +// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : vector<2xf32> +// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind} : vector<2xf32>, f32 +// CHECK: store %[[T9]], %[[C]][] : memref> +// CHECK: return +func @matvecmax2x2(%arg0: memref>, %arg1: memref>, + %arg2: memref>) { + %A = load %arg0[] : memref> + %x = load %arg1[] : memref> + %b = load %arg2[] : memref> + %0 = vector.contract #matvecmax_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32> + store %0, %arg2[] : memref> + return +} + // CHECK-LABEL: func @mattransvec2x2 // CHECK-SAME: %[[A:.*0]]: memref> // CHECK-SAME: %[[B:.*1]]: memref> @@ -75,10 +106,10 @@ // CHECK: %[[T2:.*]] = load %[[C]][] : memref> // CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2x2xf32> // CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : vector<2xf32> -// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] : vector<2xf32>, f32 +// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] {kind = #vector.kind} : vector<2xf32>, f32 // CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2x2xf32> // CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : vector<2xf32> -// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] : vector<2xf32>, f32 +// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind} : vector<2xf32>, f32 // CHECK: store %[[T8]], %[[C]][] : memref> // CHECK: return func @mattransvec2x2(%arg0: memref>, %arg1: memref>, @@ -101,10 +132,10 @@ // CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32> // CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2x2xf32> // CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : vector<2xf32> -// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] : vector<2xf32>, f32 +// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind} : vector<2xf32>, f32 // CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2x2xf32> // CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : vector<2xf32> -// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] : vector<2xf32>, f32 +// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind} : vector<2xf32>, f32 // CHECK: store %[[T9]], %[[C]][] : memref> // CHECK: return func @vecmat2x2(%arg0: memref>, %arg1: memref>, @@ -126,10 +157,10 @@ // CHECK: %[[T2:.*]] = load %[[C]][] : memref> // CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2x2xf32> // CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : vector<2xf32> -// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] : vector<2xf32>, f32 +// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] {kind = #vector.kind} : vector<2xf32>, f32 // CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2x2xf32> // CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : vector<2xf32> -// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] : vector<2xf32>, f32 +// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind} : vector<2xf32>, f32 // CHECK: store %[[T8]], %[[C]][] : memref> // CHECK: return func @vecmattrans2x2(%arg0: memref>, %arg1: memref>, diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-transforms.mlir @@ -92,56 +92,56 @@ // CHECK-NEXT: %[[TG3:.*]] = vector.tuple_get %[[ES3]], 0 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG4:.*]] = vector.tuple_get %[[ES4]], 0 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> // CHECK-NEXT: %[[TG5:.*]] = vector.tuple_get %[[ES5]], 0 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> -// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG1]], %[[TG2]], %[[TG3]], %[[TG4]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[TG1]], %[[TG2]], %[[TG3]], %[[TG4]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[TG6:.*]] = vector.tuple_get %[[ES1]], 1 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG7:.*]] = vector.tuple_get %[[ES2]], 2 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG8:.*]] = vector.tuple_get %[[ES4]], 1 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> // CHECK-NEXT: %[[TG9:.*]] = vector.tuple_get %[[ES5]], 2 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> -// CHECK-NEXT: %[[R2S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG6]], %[[TG7]], %[[R1S00]], %[[TG8]], %[[TG9]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R2S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[TG6]], %[[TG7]], %[[R1S00]], %[[TG8]], %[[TG9]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[TG10:.*]] = vector.tuple_get %[[ES1]], 2 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG11:.*]] = vector.tuple_get %[[ES2]], 4 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG12:.*]] = vector.tuple_get %[[ES4]], 2 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> // CHECK-NEXT: %[[TG13:.*]] = vector.tuple_get %[[ES5]], 4 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> -// CHECK-NEXT: %[[R3S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG10]], %[[TG11]], %[[R2S00]], %[[TG12]], %[[TG13]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R3S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[TG10]], %[[TG11]], %[[R2S00]], %[[TG12]], %[[TG13]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // Reducing output vector [0, 2] // CHECK-NEXT: %[[TG14:.*]] = vector.tuple_get %[[ES2]], 1 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG15:.*]] = vector.tuple_get %[[ES3]], 1 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG16:.*]] = vector.tuple_get %[[ES5]], 1 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> -// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG1]], %[[TG14]], %[[TG15]], %[[TG4]], %[[TG16]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[TG1]], %[[TG14]], %[[TG15]], %[[TG4]], %[[TG16]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[TG17:.*]] = vector.tuple_get %[[ES2]], 3 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG18:.*]] = vector.tuple_get %[[ES5]], 3 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> -// CHECK-NEXT: %[[R2S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG6]], %[[TG17]], %[[R1S02]], %[[TG8]], %[[TG18]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R2S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[TG6]], %[[TG17]], %[[R1S02]], %[[TG8]], %[[TG18]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[TG19:.*]] = vector.tuple_get %[[ES2]], 5 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG20:.*]] = vector.tuple_get %[[ES5]], 5 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> -// CHECK-NEXT: %[[R3S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG10]], %[[TG19]], %[[R2S02]], %[[TG12]], %[[TG20]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R3S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[TG10]], %[[TG19]], %[[R2S02]], %[[TG12]], %[[TG20]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // Reducing output vector [2, 0] // CHECK-NEXT: %[[TG21:.*]] = vector.tuple_get %[[ES1]], 3 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG22:.*]] = vector.tuple_get %[[ES3]], 2 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG23:.*]] = vector.tuple_get %[[ES4]], 3 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> -// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG21]], %[[TG2]], %[[TG22]], %[[TG23]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[TG21]], %[[TG2]], %[[TG22]], %[[TG23]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[TG24:.*]] = vector.tuple_get %[[ES1]], 4 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG25:.*]] = vector.tuple_get %[[ES4]], 4 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> -// CHECK-NEXT: %[[R2S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG24]], %[[TG7]], %[[R1S20]], %[[TG25]], %[[TG9]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R2S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[TG24]], %[[TG7]], %[[R1S20]], %[[TG25]], %[[TG9]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[TG26:.*]] = vector.tuple_get %[[ES1]], 5 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG27:.*]] = vector.tuple_get %[[ES4]], 5 : tuple, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>> -// CHECK-NEXT: %[[R3S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG26]], %[[TG11]], %[[R2S20]], %[[TG27]], %[[TG13]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R3S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[TG26]], %[[TG11]], %[[R2S20]], %[[TG27]], %[[TG13]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // Reducing output vector [2, 2] // CHECK-NEXT: %[[TG28:.*]] = vector.tuple_get %[[ES3]], 3 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> -// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG21]], %[[TG14]], %[[TG28]], %[[TG23]], %[[TG16]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> -// CHECK-NEXT: %[[R2S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG24]], %[[TG17]], %[[R1S22]], %[[TG25]], %[[TG18]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> -// CHECK-NEXT: %[[R3S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG26]], %[[TG19]], %[[R2S22]], %[[TG27]], %[[TG20]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[TG21]], %[[TG14]], %[[TG28]], %[[TG23]], %[[TG16]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R2S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[TG24]], %[[TG17]], %[[R1S22]], %[[TG25]], %[[TG18]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R3S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[TG26]], %[[TG19]], %[[R2S22]], %[[TG27]], %[[TG20]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[RES0:.*]] = vector.tuple %[[R3S00]], %[[R3S02]], %[[R3S20]], %[[R3S22]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32> // CHECK-NEXT: %[[RES1:.*]] = vector.insert_slices %[[RES0]], [2, 2], [1, 1] : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32> @@ -187,26 +187,26 @@ // CHECK-NEXT: %[[TG3:.*]] = vector.tuple_get %[[ES3]], 0 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG4:.*]] = vector.tuple_get %[[ES4]], 0 : tuple, vector<2x2xi1>> // CHECK-NEXT: %[[TG5:.*]] = vector.tuple_get %[[ES5]], 0 : tuple, vector<2x2xi1>> -// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[TG1]], %[[TG2]], %[[TG3]], %[[TG4]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[TG1]], %[[TG2]], %[[TG3]], %[[TG4]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // Reducing output vector [0, 2] // CHECK-NEXT: %[[TG6:.*]] = vector.tuple_get %[[ES2]], 1 : tuple, vector<2x2xf32>> // CHECK-NEXT: %[[TG7:.*]] = vector.tuple_get %[[ES3]], 1 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG8:.*]] = vector.tuple_get %[[ES5]], 1 : tuple, vector<2x2xi1>> -// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[TG1]], %[[TG6]], %[[TG7]], %[[TG4]], %[[TG8]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[TG1]], %[[TG6]], %[[TG7]], %[[TG4]], %[[TG8]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // Reducing output vector [2, 0] // CHECK-NEXT: %[[TG9:.*]] = vector.tuple_get %[[ES1]], 1 : tuple, vector<2x2xf32>> // CHECK-NEXT: %[[TG10:.*]] = vector.tuple_get %[[ES3]], 2 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> // CHECK-NEXT: %[[TG11:.*]] = vector.tuple_get %[[ES4]], 1 : tuple, vector<2x2xi1>> -// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[TG9]], %[[TG2]], %[[TG10]], %[[TG11]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[TG9]], %[[TG2]], %[[TG10]], %[[TG11]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // Reducing output vector [2, 2] // CHECK-NEXT: %[[TG12:.*]] = vector.tuple_get %[[ES3]], 3 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> -// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[TG9]], %[[TG6]], %[[TG12]], %[[TG11]], %[[TG8]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[TG9]], %[[TG6]], %[[TG12]], %[[TG11]], %[[TG8]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[RES0:.*]] = vector.tuple %[[R1S00]], %[[R1S02]], %[[R1S20]], %[[R1S22]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32> // CHECK-NEXT: %[[RES1:.*]] = vector.insert_slices %[[RES0]], [2, 2], [1, 1] : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32> @@ -241,10 +241,10 @@ // CHECK-NEXT: %[[VTR6:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> // CHECK-NEXT: %[[VTR7:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> -// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> -// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> -// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> -// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] {masked = [false, false]} : vector<2x2xf32>, memref<4x4xf32> // CHECK-NEXT: vector.transfer_write %[[R1]], %{{.*}}[%[[C0]], %[[C2]]] {masked = [false, false]} : vector<2x2xf32>, memref<4x4xf32> @@ -572,10 +572,10 @@ // CHECK-NEXT: %[[VTR6:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32> // CHECK-NEXT: %[[VTR7:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32> -// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> -// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> -// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> -// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[VTW0:.*]] = vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] {masked = [false, false]} : vector<2x2xf32>, tensor<4x4xf32> // CHECK-NEXT: %[[VTW1:.*]] = vector.transfer_write %[[R1]], %[[VTW0]][%[[C0]], %[[C2]]] {masked = [false, false]} : vector<2x2xf32>, tensor<4x4xf32> diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp --- a/mlir/tools/mlir-tblgen/EnumsGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp @@ -199,7 +199,7 @@ if (auto val = enumerant.getValue()) os << formatv(" if ({0}u & val) {{ strs.push_back(\"{1}\"); " "val &= ~{0}u; }\n", - val, enumerant.getSymbol()); + val, enumerant.getStr()); } // If we have unknown bit set, return an empty string to signal errors. os << "\n if (val) return \"\";\n"; @@ -261,8 +261,7 @@ for (const auto &enumerant : enumerants) { // Skip the special enumerant for None. if (auto val = enumerant.getValue()) - os.indent(6) << formatv(".Case(\"{0}\", {1})\n", enumerant.getSymbol(), - val); + os.indent(6) << formatv(".Case(\"{0}\", {1})\n", enumerant.getStr(), val); } os.indent(6) << ".Default(::llvm::None);\n";