diff --git a/mlir/include/mlir/Dialect/Vector/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Vector/IR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Vector/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Vector/IR/CMakeLists.txt @@ -4,5 +4,7 @@ set(LLVM_TARGET_DEFINITIONS VectorOps.td) mlir_tablegen(VectorOpsEnums.h.inc -gen-enum-decls) mlir_tablegen(VectorOpsEnums.cpp.inc -gen-enum-defs) +mlir_tablegen(VectorOpsAttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(VectorOpsAttrDefs.cpp.inc -gen-attrdef-defs) add_public_tablegen_target(MLIRVectorOpsEnumsIncGen) add_dependencies(mlir-headers MLIRVectorOpsEnumsIncGen) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h @@ -29,6 +29,9 @@ // Pull in all enum type definitions and utility function declarations. #include "mlir/Dialect/Vector/IR/VectorOpsEnums.h.inc" +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.h.inc" + namespace mlir { class MLIRContext; class RewritePatternSet; @@ -113,22 +116,6 @@ /// chain. void populatePropagateVectorDistributionPatterns(RewritePatternSet &patterns); -/// An attribute that specifies the combining function for `vector.contract`, -/// and `vector.reduction`. -class CombiningKindAttr - : public Attribute::AttrBase { -public: - using Base::Base; - - static CombiningKindAttr get(CombiningKind kind, MLIRContext *context); - - CombiningKind getKind() const; - - void print(AsmPrinter &p) const; - static Attribute parse(AsmParser &parser, Type type); -}; - /// Collects patterns to progressively lower vector.broadcast ops on high-D /// vectors to low-D vector ops. void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns); diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -58,15 +58,10 @@ let genSpecializedAttr = 0; } -def Vector_CombiningKindAttr : DialectAttr< - Vector_Dialect, - CPred<"$_self.isa<::mlir::vector::CombiningKindAttr>()">, - "Kind of combining function for contractions and reductions"> { - let storageType = "::mlir::vector::CombiningKindAttr"; - let returnType = "::mlir::vector::CombiningKind"; - let convertFromStorage = "$_self.getKind()"; - let constBuilderCall = - "::mlir::vector::CombiningKindAttr::get($0, $_builder.getContext())"; +/// An attribute that specifies the combining function for `vector.contract`, +/// and `vector.reduction`. +def Vector_CombiningKindAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; } // TODO: Add an attribute to specify a different algebra with operators other diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -13,6 +13,11 @@ #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include + +#include "llvm/ADT/StringSet.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/ADT/bit.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Arithmetic/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -30,10 +35,6 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LLVM.h" -#include "mlir/Support/MathExtras.h" -#include "llvm/ADT/StringSet.h" -#include "llvm/ADT/bit.h" -#include #include "mlir/Dialect/Vector/IR/VectorOpsDialect.cpp.inc" // Pull in all enum type and utility function definitions. @@ -227,91 +228,15 @@ } // namespace vector } // namespace mlir -CombiningKindAttr CombiningKindAttr::get(CombiningKind kind, - MLIRContext *context) { - return Base::get(context, static_cast(kind)); -} - -CombiningKind CombiningKindAttr::getKind() const { - return static_cast(getImpl()->value); -} - -static constexpr const CombiningKind combiningKindsList[] = { - // clang-format off - CombiningKind::ADD, - CombiningKind::MUL, - CombiningKind::MINUI, - CombiningKind::MINSI, - CombiningKind::MINF, - CombiningKind::MAXUI, - CombiningKind::MAXSI, - CombiningKind::MAXF, - CombiningKind::AND, - CombiningKind::OR, - CombiningKind::XOR, - // clang-format on -}; - -void CombiningKindAttr::print(AsmPrinter &printer) const { - printer << "<"; - auto kinds = llvm::make_filter_range(combiningKindsList, [&](auto kind) { - return bitEnumContains(this->getKind(), kind); - }); - llvm::interleaveComma(kinds, printer, - [&](auto kind) { printer << stringifyEnum(kind); }); - printer << ">"; -} - -Attribute CombiningKindAttr::parse(AsmParser &parser, Type type) { - if (failed(parser.parseLess())) - return {}; - - StringRef elemName; - if (failed(parser.parseKeyword(&elemName))) - return {}; - - auto kind = symbolizeCombiningKind(elemName); - if (!kind) { - parser.emitError(parser.getNameLoc(), "Unknown combining kind: ") - << elemName; - return {}; - } - - if (failed(parser.parseGreater())) - return {}; - - return CombiningKindAttr::get(*kind, parser.getContext()); -} - -Attribute VectorDialect::parseAttribute(DialectAsmParser &parser, - Type type) const { - StringRef attrKind; - if (parser.parseKeyword(&attrKind)) - return {}; - - if (attrKind == "kind") - return CombiningKindAttr::parse(parser, {}); - - parser.emitError(parser.getNameLoc(), "Unknown attribute type: ") << attrKind; - return {}; -} - -void VectorDialect::printAttribute(Attribute attr, - DialectAsmPrinter &os) const { - if (auto ck = attr.dyn_cast()) { - os << "kind"; - ck.print(os); - return; - } - llvm_unreachable("Unknown attribute type"); -} - //===----------------------------------------------------------------------===// // VectorDialect //===----------------------------------------------------------------------===// void VectorDialect::initialize() { - addAttributes(); + addAttributes< +#define GET_ATTRDEF_LIST +#include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.cpp.inc" + >(); addOperations< #define GET_OP_LIST @@ -558,7 +483,7 @@ result.addAttribute(::mlir::getIndexingMapsAttrName(), indexingMaps); result.addAttribute(::mlir::getIteratorTypesAttrName(), iteratorTypes); result.addAttribute(ContractionOp::getKindAttrStrName(), - CombiningKindAttr::get(kind, builder.getContext())); + CombiningKindAttr::get(builder.getContext(), kind)); } ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) { @@ -588,8 +513,8 @@ dictAttr.getValue().end()); if (!result.attributes.get(ContractionOp::getKindAttrStrName())) { result.addAttribute(ContractionOp::getKindAttrStrName(), - CombiningKindAttr::get(ContractionOp::getDefaultKind(), - result.getContext())); + CombiningKindAttr::get(result.getContext(), + ContractionOp::getDefaultKind())); } if (masksInfo.empty()) return success(); @@ -2385,8 +2310,8 @@ if (!result.attributes.get(OuterProductOp::getKindAttrStrName())) { result.attributes.append( OuterProductOp::getKindAttrStrName(), - CombiningKindAttr::get(OuterProductOp::getDefaultKind(), - result.getContext())); + CombiningKindAttr::get(result.getContext(), + OuterProductOp::getDefaultKind())); } return failure( @@ -5179,5 +5104,8 @@ // TableGen'd op method definitions //===----------------------------------------------------------------------===// +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.cpp.inc" + #define GET_OP_CLASSES #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc" diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1111,7 +1111,8 @@ // ----- func.func @reduce_unknown_kind(%arg0: vector<16xf32>) -> f32 { - // expected-error@+1 {{custom op 'vector.reduction' Unknown combining kind: joho}} + // expected-error@+2 {{custom op 'vector.reduction' failed to parse Vector_CombiningKindAttr parameter 'value' which is to be a `::mlir::vector::CombiningKind`}} + // expected-error@+1 {{custom op 'vector.reduction' expected ::mlir::vector::CombiningKind to be one of: }} %0 = vector.reduction , %arg0 : vector<16xf32> into f32 } diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -7698,6 +7698,14 @@ ["-gen-enum-defs"], "include/mlir/Dialect/Vector/IR/VectorOpsEnums.cpp.inc", ), + ( + ["-gen-attrdef-decls"], + "include/mlir/Dialect/Vector/IR/VectorOpsAttrDefs.h.inc", + ), + ( + ["-gen-attrdef-defs"], + "include/mlir/Dialect/Vector/IR/VectorOpsAttrDefs.cpp.inc", + ), ( ["-gen-op-doc"], "g3doc/Dialects/Vector/VectorOps.md",