diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -200,7 +200,7 @@ OpBuilder<(ins "ICmpPredicate":$predicate, "Value":$lhs, "Value":$rhs), [{ build($_builder, $_state, IntegerType::get(lhs.getType().getContext(), 1), - $_builder.getI64IntegerAttr(static_cast(predicate)), lhs, rhs); + predicate, lhs, rhs); }]>]; let parser = [{ return parseCmpOp(parser, result); }]; let printer = [{ printICmpOp(p, *this); }]; @@ -246,14 +246,6 @@ let llvmBuilder = [{ $res = builder.CreateFCmp(getLLVMCmpPredicate($predicate), $lhs, $rhs); }]; - let builders = [ - OpBuilder<(ins "FCmpPredicate":$predicate, "Value":$lhs, "Value":$rhs, - CArg<"FastmathFlags", "{}">:$fmf), - [{ - build($_builder, $_state, IntegerType::get(lhs.getType().getContext(), 1), - $_builder.getI64IntegerAttr(static_cast(predicate)), lhs, rhs, - ::mlir::LLVM::FMFAttr::get($_builder.getContext(), fmf)); - }]>]; let parser = [{ return parseCmpOp(parser, result); }]; let printer = [{ printFCmpOp(p, *this); }]; } diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVEnums.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVEnums.h --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVEnums.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVEnums.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_SPIRV_IR_SPIRVENUMS_H_ #define MLIR_DIALECT_SPIRV_IR_SPIRVENUMS_H_ +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/DenseMapInfo.h" #include "llvm/ADT/StringRef.h" diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td @@ -184,7 +184,7 @@ let builders = [ OpBuilder<(ins "Value":$basePtr, - CArg<"IntegerAttr", "{}">:$memory_access, + CArg<"MemoryAccessAttr", "{}">:$memory_access, CArg<"IntegerAttr", "{}">:$alignment)> ]; } diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -53,6 +53,7 @@ COMBINING_KIND_MAX, COMBINING_KIND_AND, COMBINING_KIND_OR, COMBINING_KIND_XOR]> { let cppNamespace = "::mlir::vector"; + let genSpecializedAttr = 0; } def Vector_CombiningKindAttr : DialectAttr< diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1142,7 +1142,9 @@ } // Additional information for an enum attribute. -class EnumAttrInfo cases> { +class EnumAttrInfo< + string name, list cases, Attr baseClass> : + Attr { // The C++ enum class name string className = name; @@ -1188,6 +1190,28 @@ // static constexpr unsigned (); // ``` string maxEnumValFnName = "getMaxEnumValFor" # name; + + // Generate specialized Attribute class + bit genSpecializedAttr = 1; + // The underlying Attribute class, which holds the enum value + Attr baseAttrClass = baseClass; + // The name of specialized Enum Attribute class + string specializedAttrClassName = name # Attr; + + // Override Attr class fields for specialized class + let predicate = !if(genSpecializedAttr, + CPred<"$_self.isa<" # cppNamespace # "::" # specializedAttrClassName # ">()">, + baseAttrClass.predicate); + let storageType = !if(genSpecializedAttr, + cppNamespace # "::" # specializedAttrClassName, + baseAttrClass.storageType); + let returnType = !if(genSpecializedAttr, + cppNamespace # "::" # className, + baseAttrClass.returnType); + let constBuilderCall = !if(genSpecializedAttr, + cppNamespace # "::" # specializedAttrClassName # "::get($_builder.getContext(), $0)", + baseAttrClass.constBuilderCall); + let valueType = baseAttrClass.valueType; } // An enum attribute backed by StringAttr. @@ -1195,47 +1219,44 @@ // Op attributes of this kind are stored as StringAttr. Extra verification will // be generated on the string though: only the symbols of the allowed cases are // permitted as the string value. -class StrEnumAttr cases> - : EnumAttrInfo, +class StrEnumAttr cases> : + EnumAttrInfo]>, !if(!empty(summary), "allowed string cases: " # !interleave(!foreach(case, cases, "'" # case.symbol # "'"), ", "), - summary)>; + summary)>> { + // Disable specialized Attribute class for `StringAttr` backend by default. + let genSpecializedAttr = 0; +} // An enum attribute backed by IntegerAttr. // // Op attributes of this kind are stored as IntegerAttr. Extra verification will // be generated on the integer though: only the values of the allowed cases are // permitted as the integer value. -class IntEnumAttr cases> : - EnumAttrInfo, - SignlessIntegerAttrBase { +class IntEnumAttrBase cases, string summary> : + SignlessIntegerAttrBase { let predicate = And<[ - SignlessIntegerAttrBase.predicate, + SignlessIntegerAttrBase.predicate, Or]>; } -class I32EnumAttr cases> : +class IntEnumAttr cases> : + EnumAttrInfo>; + +class I32EnumAttr cases> : IntEnumAttr { - let returnType = cppNamespace # "::" # name; let underlyingType = "uint32_t"; - let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())"; - let constBuilderCall = - "$_builder.getI32IntegerAttr(static_cast($0))"; } -class I64EnumAttr cases> : +class I64EnumAttr cases> : IntEnumAttr { - let returnType = cppNamespace # "::" # name; let underlyingType = "uint64_t"; - let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())"; - let constBuilderCall = - "$_builder.getI64IntegerAttr(static_cast($0))"; } // A bit enum stored with 32-bit IntegerAttr. @@ -1244,9 +1265,8 @@ // be generated on the integer to make sure only allowed bit are set. Besides, // helper methods are generated to parse a string separated with a specified // delimiter to a symbol and vice versa. -class BitEnumAttr cases> : - EnumAttrInfo, SignlessIntegerAttrBase { +class BitEnumAttrBase cases, string summary> : + SignlessIntegerAttrBase { let predicate = And<[ I32Attr.predicate, // Make sure we don't have unknown bit set. @@ -1254,12 +1274,11 @@ # !interleave(!foreach(case, cases, case.value # "u"), "|") # ")))"> ]>; +} - let returnType = cppNamespace # "::" # name; +class BitEnumAttr cases> : + EnumAttrInfo> { let underlyingType = "uint32_t"; - let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())"; - let constBuilderCall = - "$_builder.getI32IntegerAttr(static_cast($0))"; // We need to return a string because we may concatenate symbols for multiple // bits together. diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h --- a/mlir/include/mlir/TableGen/Attribute.h +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -202,6 +202,10 @@ // Returns all allowed cases for this enum attribute. std::vector getAllCases() const; + + bool genSpecializedAttr() const; + llvm::Record *getBaseAttrClass() const; + StringRef getSpecializedAttrClassName() const; }; class StructFieldAttr { diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp --- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp +++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp @@ -155,9 +155,7 @@ // header to merge. scf::ForOpAdaptor forOperands(operands); auto loc = forOp.getLoc(); - auto loopControl = rewriter.getI32IntegerAttr( - static_cast(spirv::LoopControl::None)); - auto loopOp = rewriter.create(loc, loopControl); + auto loopOp = rewriter.create(loc, spirv::LoopControl::None); loopOp.addEntryAndMergeBlock(); OpBuilder::InsertionGuard guard(rewriter); @@ -238,11 +236,9 @@ scf::IfOpAdaptor ifOperands(operands); auto loc = ifOp.getLoc(); - // Create `spv.mlir.selection` operation, selection header block and merge - // block. - auto selectionControl = rewriter.getI32IntegerAttr( - static_cast(spirv::SelectionControl::None)); - auto selectionOp = rewriter.create(loc, selectionControl); + // Create `spv.selection` operation, selection header block and merge block. + auto selectionOp = + rewriter.create(loc, spirv::SelectionControl::None); auto *mergeBlock = rewriter.createBlock(&selectionOp.body(), selectionOp.body().end()); rewriter.create(loc); diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -826,10 +826,8 @@ return failure(); rewriter.template replaceOpWithNewOp( - operation, dstType, - rewriter.getI64IntegerAttr(static_cast(predicate)), - operation.operand1(), operation.operand2(), - LLVM::FMFAttr::get(operation.getContext(), {})); + operation, dstType, predicate, operation.operand1(), + operation.operand2()); return success(); } }; @@ -849,9 +847,8 @@ return failure(); rewriter.template replaceOpWithNewOp( - operation, dstType, - rewriter.getI64IntegerAttr(static_cast(predicate)), - operation.operand1(), operation.operand2()); + operation, dstType, predicate, operation.operand1(), + operation.operand2()); return success(); } }; diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -3069,8 +3069,7 @@ rewriter.replaceOpWithNewOp( cmpiOp, typeConverter->convertType(cmpiOp.getResult().getType()), - rewriter.getI64IntegerAttr(static_cast( - convertCmpPredicate(cmpiOp.getPredicate()))), + convertCmpPredicate(cmpiOp.getPredicate()), transformed.lhs(), transformed.rhs()); return success(); @@ -3085,12 +3084,10 @@ ConversionPatternRewriter &rewriter) const override { CmpFOpAdaptor transformed(operands); - auto fmf = LLVM::FMFAttr::get(cmpfOp.getContext(), {}); rewriter.replaceOpWithNewOp( cmpfOp, typeConverter->convertType(cmpfOp.getResult().getType()), - rewriter.getI64IntegerAttr(static_cast( - convertCmpPredicate(cmpfOp.getPredicate()))), - transformed.lhs(), transformed.rhs(), fmf); + convertCmpPredicate(cmpfOp.getPredicate()), + transformed.lhs(), transformed.rhs()); return success(); } diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -1017,7 +1017,7 @@ srcBits, dstBits, rewriter); Value spvLoadOp = rewriter.create( loc, dstType, adjustedPtr, - loadOp->getAttrOfType( + loadOp->getAttrOfType( spirv::attributeName()), loadOp->getAttrOfType("alignment")); diff --git a/mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp b/mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp --- a/mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp @@ -36,7 +36,7 @@ MLIRContext *context = map.getContext(); OpBuilder builder(context); return ParallelLoopDimMapping::get( - builder.getI64IntegerAttr(static_cast(processor)), + ProcessorAttr::get(builder.getContext(), processor), AffineMapAttr::get(map), AffineMapAttr::get(bound), context); } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVEnums.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVEnums.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVEnums.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVEnums.cpp @@ -12,6 +12,8 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" +#include "mlir/IR/BuiltinTypes.h" + #include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -1659,7 +1659,7 @@ spirv::FuncOp function, ArrayRef interfaceVars) { build(builder, state, - builder.getI32IntegerAttr(static_cast(executionModel)), + spirv::ExecutionModelAttr::get(builder.getContext(), executionModel), builder.getSymbolRefAttr(function), builder.getArrayAttr(interfaceVars)); } @@ -1721,7 +1721,7 @@ spirv::ExecutionMode executionMode, ArrayRef params) { build(builder, state, builder.getSymbolRefAttr(function), - builder.getI32IntegerAttr(static_cast(executionMode)), + spirv::ExecutionModeAttr::get(builder.getContext(), executionMode), builder.getI32ArrayAttr(params)); } @@ -2243,10 +2243,10 @@ //===----------------------------------------------------------------------===// void spirv::LoadOp::build(OpBuilder &builder, OperationState &state, - Value basePtr, IntegerAttr memory_access, + Value basePtr, MemoryAccessAttr memoryAccess, IntegerAttr alignment) { auto ptrType = basePtr.getType().cast(); - build(builder, state, ptrType.getPointeeType(), basePtr, memory_access, + build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess, alignment); } @@ -2784,9 +2784,8 @@ spirv::SelectionOp spirv::SelectionOp::createIfThen( Location loc, Value condition, function_ref thenBody, OpBuilder &builder) { - auto selectionControl = builder.getI32IntegerAttr( - static_cast(spirv::SelectionControl::None)); - auto selectionOp = builder.create(loc, selectionControl); + auto selectionOp = + builder.create(loc, spirv::SelectionControl::None); selectionOp.addMergeBlock(); Block *mergeBlock = selectionOp.getMergeBlock(); diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -231,6 +231,18 @@ return cases; } +bool EnumAttr::genSpecializedAttr() const { + return def->getValueAsBit("genSpecializedAttr"); +} + +llvm::Record *EnumAttr::getBaseAttrClass() const { + return def->getValueAsDef("baseAttrClass"); +} + +StringRef EnumAttr::getSpecializedAttrClassName() const { + return def->getValueAsString("specializedAttrClassName"); +} + StructFieldAttr::StructFieldAttr(const llvm::Record *record) : def(record) { assert(def->isSubClassOf("StructFieldAttr") && "must be subclass of TableGen 'StructFieldAttr' class"); diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp --- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp @@ -331,7 +331,8 @@ return emitError(unknownLoc, "missing Execution Model specification in OpEntryPoint"); } - auto execModel = opBuilder.getI32IntegerAttr(words[wordIndex++]); + auto execModel = spirv::ExecutionModelAttr::get( + context, static_cast(words[wordIndex++])); if (wordIndex >= words.size()) { return emitError(unknownLoc, "missing in OpEntryPoint"); } @@ -383,7 +384,8 @@ if (wordIndex >= words.size()) { return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode"); } - auto execMode = opBuilder.getI32IntegerAttr(words[wordIndex++]); + auto execMode = spirv::ExecutionModeAttr::get( + context, static_cast(words[wordIndex++])); // Get the values SmallVector attrListElems; @@ -417,8 +419,11 @@ argAttrs.push_back(argAttr); } - opBuilder.create(unknownLoc, argAttrs[0], - argAttrs[1], argAttrs[2]); + opBuilder.create( + unknownLoc, argAttrs[0].cast(), + argAttrs[1].cast(), + argAttrs[2].cast()); + return success(); } @@ -483,8 +488,9 @@ argAttrs.push_back(argAttr); } - opBuilder.create(unknownLoc, argAttrs[0], - argAttrs[1]); + opBuilder.create( + unknownLoc, argAttrs[0].cast(), + argAttrs[1].cast()); return success(); } diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -1640,7 +1640,7 @@ // merge block so that the newly created SelectionOp will be inserted there. OpBuilder builder(&mergeBlock->front()); - auto control = builder.getI32IntegerAttr(selectionControl); + auto control = static_cast(selectionControl); auto selectionOp = builder.create(location, control); selectionOp.addMergeBlock(); @@ -1652,7 +1652,7 @@ // merge block so that the newly created LoopOp will be inserted there. OpBuilder builder(&mergeBlock->front()); - auto control = builder.getI32IntegerAttr(loopControl); + auto control = static_cast(loopControl); auto loopOp = builder.create(location, control); loopOp.addEntryAndMergeBlock(); diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1052,39 +1052,39 @@ } // Test using multi-result op as a whole -def : Pat<(ThreeResultOp MultiResultOpKind1), - (AnotherThreeResultOp MultiResultOpKind1)>; +def : Pat<(ThreeResultOp MultiResultOpKind1:$kind), + (AnotherThreeResultOp $kind)>; // Test using multi-result op as a whole for partial replacement -def : Pattern<(ThreeResultOp MultiResultOpKind2), - [(TwoResultOp MultiResultOpKind2), - (OneResultOp1 MultiResultOpKind2)]>; -def : Pattern<(ThreeResultOp MultiResultOpKind3), - [(OneResultOp2 MultiResultOpKind3), - (AnotherTwoResultOp MultiResultOpKind3)]>; +def : Pattern<(ThreeResultOp MultiResultOpKind2:$kind), + [(TwoResultOp $kind), + (OneResultOp1 $kind)]>; +def : Pattern<(ThreeResultOp MultiResultOpKind3:$kind), + [(OneResultOp2 $kind), + (AnotherTwoResultOp $kind)]>; // Test using results separately in a multi-result op -def : Pattern<(ThreeResultOp MultiResultOpKind4), - [(TwoResultOp:$res1__0 MultiResultOpKind4), - (OneResultOp1 MultiResultOpKind4), - (TwoResultOp:$res2__1 MultiResultOpKind4)]>; +def : Pattern<(ThreeResultOp MultiResultOpKind4:$kind), + [(TwoResultOp:$res1__0 $kind), + (OneResultOp1 $kind), + (TwoResultOp:$res2__1 $kind)]>; // Test referencing a single value in the value pack // This rule only matches TwoResultOp if its second result has no use. -def : Pattern<(TwoResultOp:$res MultiResultOpKind5), - [(OneResultOp2 MultiResultOpKind5), - (OneResultOp1 MultiResultOpKind5)], +def : Pattern<(TwoResultOp:$res MultiResultOpKind5:$kind), + [(OneResultOp2 $kind), + (OneResultOp1 $kind)], [(HasNoUseOf:$res__1)]>; // Test using auxiliary ops for replacing multi-result op def : Pattern< - (ThreeResultOp MultiResultOpKind6), [ + (ThreeResultOp MultiResultOpKind6:$kind), [ // Auxiliary op generated to help building the final result but not // directly used to replace the source op's results. - (TwoResultOp:$interm MultiResultOpKind6), + (TwoResultOp:$interm $kind), (OneResultOp3 $interm__1), - (AnotherTwoResultOp MultiResultOpKind6) + (AnotherTwoResultOp $kind) ]>; //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp --- a/mlir/tools/mlir-tblgen/EnumsGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/TableGen/Attribute.h" +#include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" @@ -22,12 +23,16 @@ using llvm::formatv; using llvm::isDigit; +using llvm::PrintFatalError; using llvm::raw_ostream; using llvm::Record; using llvm::RecordKeeper; using llvm::StringRef; +using mlir::tblgen::Attribute; using mlir::tblgen::EnumAttr; using mlir::tblgen::EnumAttrCase; +using mlir::tblgen::FmtContext; +using mlir::tblgen::tgfmt; static std::string makeIdentifier(StringRef str) { if (!str.empty() && isDigit(static_cast(str.front()))) { @@ -303,6 +308,78 @@ << "}\n\n"; } +static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) { + EnumAttr enumAttr(enumDef); + StringRef enumName = enumAttr.getEnumClassName(); + StringRef symToStrFnName = enumAttr.getSymbolToStringFnName(); + StringRef strToSymFnName = enumAttr.getStringToSymbolFnName(); + StringRef attrClassName = enumAttr.getSpecializedAttrClassName(); + llvm::Record *baseAttrDef = enumAttr.getBaseAttrClass(); + Attribute baseAttr(baseAttrDef); + + // Emit classof method + + os << formatv("bool {0}::classof(::mlir::Attribute attr) {{\n", + attrClassName); + + mlir::tblgen::Pred baseAttrPred = baseAttr.getPredicate(); + if (baseAttrPred.isNull()) + PrintFatalError("ERROR: baseAttrClass for EnumAttr has no Predicate\n"); + + std::string condition = baseAttrPred.getCondition(); + FmtContext verifyCtx; + verifyCtx.withSelf("attr"); + os << tgfmt(" return $0;\n", /*ctx=*/nullptr, tgfmt(condition, &verifyCtx)); + + os << "}\n"; + + // Emit get method + + os << formatv("{0} {0}::get(::mlir::MLIRContext *context, {1} val) {{\n", + attrClassName, enumName); + + if (enumAttr.isSubClassOf("StrEnumAttr")) { + os << formatv(" ::mlir::StringAttr baseAttr = " + "::mlir::StringAttr::get(context, {0}(val));\n", + symToStrFnName); + } else { + StringRef underlyingType = enumAttr.getUnderlyingType(); + + // Assuming that it is IntegerAttr constraint + int64_t bitwidth = 64; + if (baseAttrDef->getValue("valueType")) { + auto *valueTypeDef = baseAttrDef->getValueAsDef("valueType"); + if (valueTypeDef->getValue("bitwidth")) + bitwidth = valueTypeDef->getValueAsInt("bitwidth"); + } + + os << formatv(" ::mlir::IntegerType intType = " + "::mlir::IntegerType::get(context, {0});\n", + bitwidth); + os << formatv(" ::mlir::IntegerAttr baseAttr = " + "::mlir::IntegerAttr::get(intType, static_cast<{0}>(val));\n", + underlyingType); + } + os << formatv(" return baseAttr.cast<{0}>();\n", attrClassName); + + os << "}\n"; + + // Emit getValue method + + os << formatv("{0} {1}::getValue() const {{\n", enumName, attrClassName); + + if (enumAttr.isSubClassOf("StrEnumAttr")) { + os << formatv(" const auto res = {0}(::mlir::StringAttr::getValue());\n", + strToSymFnName); + os << " return res.getValue();\n"; + } else { + os << formatv(" return static_cast<{0}>(::mlir::IntegerAttr::getInt());\n", + enumName); + } + + os << "}\n"; +} + static void emitUnderlyingToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) { EnumAttr enumAttr(enumDef); @@ -391,6 +468,23 @@ )"; os << formatv(symbolizeEnumStr, enumName, strToSymFnName); + const char *const attrClassDecl = R"( +class {1} : public ::mlir::{2} { +public: + using ValueType = {0}; + using ::mlir::{2}::{2}; + static bool classof(::mlir::Attribute attr); + static {1} get(::mlir::MLIRContext *context, {0} val); + {0} getValue() const; +}; +)"; + if (enumAttr.genSpecializedAttr()) { + StringRef attrClassName = enumAttr.getSpecializedAttrClassName(); + StringRef baseAttrClassName = + enumAttr.isSubClassOf("StrEnumAttr") ? "StringAttr" : "IntegerAttr"; + os << formatv(attrClassDecl, enumName, attrClassName, baseAttrClassName); + } + for (auto ns : llvm::reverse(namespaces)) os << "} // namespace " << ns << "\n"; @@ -428,6 +522,9 @@ emitUnderlyingToSymFnForIntEnum(enumDef, os); } + if (enumAttr.genSpecializedAttr()) + emitSpecializedAttrDef(enumDef, os); + for (auto ns : llvm::reverse(namespaces)) os << "} // namespace " << ns << "\n"; os << "\n"; diff --git a/mlir/unittests/TableGen/EnumsGenTest.cpp b/mlir/unittests/TableGen/EnumsGenTest.cpp --- a/mlir/unittests/TableGen/EnumsGenTest.cpp +++ b/mlir/unittests/TableGen/EnumsGenTest.cpp @@ -6,21 +6,29 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" #include "mlir/Support/LLVM.h" + #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSwitch.h" + #include "gmock/gmock.h" + #include /// Pull in generated enum utility declarations and definitions. #include "EnumsGenTest.h.inc" + #include "EnumsGenTest.cpp.inc" /// Test namespaces and enum class/utility names. using Outer::Inner::ConvertToEnum; using Outer::Inner::ConvertToString; using Outer::Inner::StrEnum; +using Outer::Inner::StrEnumAttr; TEST(EnumsGenTest, GeneratedStrEnumDefinition) { EXPECT_EQ(0u, static_cast(StrEnum::CaseA)); @@ -110,3 +118,41 @@ auto none = symbolizePrettyIntEnum("Case1"); EXPECT_FALSE(none); } + +TEST(EnumsGenTest, GeneratedIntAttributeClass) { + mlir::MLIRContext ctx; + I32Enum rawVal = I32Enum::Case5; + + I32EnumAttr enumAttr = I32EnumAttr::get(&ctx, rawVal); + EXPECT_NE(enumAttr, nullptr); + EXPECT_EQ(enumAttr.getValue(), rawVal); + + mlir::Type intType = mlir::IntegerType::get(&ctx, 32); + mlir::Attribute intAttr = mlir::IntegerAttr::get(intType, 5); + EXPECT_TRUE(intAttr.isa()); + EXPECT_EQ(intAttr, enumAttr); +} + +TEST(EnumsGenTest, GeneratedStringAttributeClass) { + mlir::MLIRContext ctx; + StrEnum rawVal = StrEnum::CaseA; + + StrEnumAttr enumAttr = StrEnumAttr::get(&ctx, rawVal); + EXPECT_NE(enumAttr, nullptr); + EXPECT_EQ(enumAttr.getValue(), rawVal); + + mlir::Attribute strAttr = mlir::StringAttr::get(&ctx, "CaseA"); + EXPECT_TRUE(strAttr.isa()); + EXPECT_EQ(strAttr, enumAttr); +} + +TEST(EnumsGenTest, GeneratedBitAttributeClass) { + mlir::MLIRContext ctx; + + mlir::Type intType = mlir::IntegerType::get(&ctx, 32); + mlir::Attribute intAttr = mlir::IntegerAttr::get( + intType, + static_cast(BitEnumWithNone::Bit1 | BitEnumWithNone::Bit3)); + EXPECT_TRUE(intAttr.isa()); + EXPECT_TRUE(intAttr.isa()); +} diff --git a/mlir/unittests/TableGen/enums.td b/mlir/unittests/TableGen/enums.td --- a/mlir/unittests/TableGen/enums.td +++ b/mlir/unittests/TableGen/enums.td @@ -15,6 +15,7 @@ let cppNamespace = "Outer::Inner"; let stringToSymbolFnName = "ConvertToEnum"; let symbolToStringFnName = "ConvertToString"; + let genSpecializedAttr = 1; } def Case5: I32EnumAttrCase<"Case5", 5>;