diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -661,10 +661,7 @@ function_ref mapping) const; /// Method for support type inquiry through isa, cast and dyn_cast. - static bool classof(Attribute attr) { - return attr.getKind() >= StandardAttributes::FIRST_ELEMENTS_ATTR && - attr.getKind() <= StandardAttributes::LAST_ELEMENTS_ATTR; - } + static bool classof(Attribute attr); protected: /// Returns the 1 dimensional flattened row-major index from the given @@ -729,10 +726,7 @@ using ElementsAttr::ElementsAttr; /// Method for support type inquiry through isa, cast and dyn_cast. - static bool classof(Attribute attr) { - return attr.getKind() == StandardAttributes::DenseIntOrFPElements || - attr.getKind() == StandardAttributes::DenseStringElements; - } + static bool classof(Attribute attr); /// Constructs a dense elements attribute from an array of element values. /// Each element attribute value is expected to be an element of 'type'. @@ -1512,13 +1506,11 @@ /// types. template class ProcessFn, typename... Args> - RetT process(Args &... args) const { - switch (attrKind) { - case StandardAttributes::DenseIntOrFPElements: + RetT process(Args &...args) const { + if (attr.isa()) return ProcessFn()(args...); - case StandardAttributes::SparseElements: + if (attr.isa()) return ProcessFn()(args...); - } llvm_unreachable("unexpected attribute kind"); } @@ -1543,22 +1535,21 @@ }; public: - ElementsAttrIterator(const ElementsAttrIterator &rhs) - : attrKind(rhs.attrKind) { + ElementsAttrIterator(const ElementsAttrIterator &rhs) : attr(rhs.attr) { process(it, rhs.it); } ~ElementsAttrIterator() { process(it); } /// Methods necessary to support random access iteration. ptrdiff_t operator-(const ElementsAttrIterator &rhs) const { - assert(attrKind == rhs.attrKind && "incompatible iterators"); + assert(attr == rhs.attr && "incompatible iterators"); return process(it, rhs.it); } bool operator==(const ElementsAttrIterator &rhs) const { - return rhs.attrKind == attrKind && process(it, rhs.it); + return rhs.attr == attr && process(it, rhs.it); } bool operator<(const ElementsAttrIterator &rhs) const { - assert(attrKind == rhs.attrKind && "incompatible iterators"); + assert(attr == rhs.attr && "incompatible iterators"); return process(it, rhs.it); } ElementsAttrIterator &operator+=(ptrdiff_t offset) { @@ -1575,14 +1566,14 @@ private: template - ElementsAttrIterator(unsigned attrKind, IteratorT &&it) - : attrKind(attrKind), it(std::forward(it)) {} + ElementsAttrIterator(Attribute attr, IteratorT &&it) + : attr(attr), it(std::forward(it)) {} /// Allow accessing the constructor. friend ElementsAttr; - /// The kind of derived elements attribute. - unsigned attrKind; + /// The parent elements attribute. + Attribute attr; /// A union containing the specific iterators for each derived kind. Iterator it; @@ -1599,13 +1590,13 @@ auto ElementsAttr::getValues() const -> iterator_range { if (DenseElementsAttr denseAttr = dyn_cast()) { auto values = denseAttr.getValues(); - return {iterator(getKind(), values.begin()), - iterator(getKind(), values.end())}; + return {iterator(*this, values.begin()), + iterator(*this, values.end())}; } if (SparseElementsAttr sparseAttr = dyn_cast()) { auto values = sparseAttr.getValues(); - return {iterator(getKind(), values.begin()), - iterator(getKind(), values.end())}; + return {iterator(*this, values.begin()), + iterator(*this, values.end())}; } llvm_unreachable("unexpected attribute kind"); } diff --git a/mlir/include/mlir/IR/Location.h b/mlir/include/mlir/IR/Location.h --- a/mlir/include/mlir/IR/Location.h +++ b/mlir/include/mlir/IR/Location.h @@ -42,10 +42,7 @@ using Attribute::Attribute; /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(Attribute attr) { - return attr.getKind() >= StandardAttributes::FIRST_LOCATION_ATTR && - attr.getKind() <= StandardAttributes::LAST_LOCATION_ATTR; - } + static bool classof(Attribute attr); }; /// This class defines the main interface for locations in MLIR and acts as a diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -1472,18 +1472,15 @@ // ODS already generates checks to make sure the result type is valid. We just // need to additionally check that the value's attribute type is consistent // with the result type. - switch (value.getKind()) { - case StandardAttributes::Integer: - case StandardAttributes::Float: { + if (value.isa()) { if (valueType != opType) return constOp.emitOpError("result type (") << opType << ") does not match value type (" << valueType << ")"; return success(); - } break; - case StandardAttributes::DenseIntOrFPElements: - case StandardAttributes::SparseElements: { + } + if (value.isa()) { if (valueType == opType) - break; + return success(); auto arrayType = opType.dyn_cast(); auto shapedType = valueType.dyn_cast(); if (!arrayType) { @@ -1497,9 +1494,8 @@ numElements *= t.getNumElements(); opElemType = t.getElementType(); } - if (!opElemType.isIntOrFloat()) { + if (!opElemType.isIntOrFloat()) return constOp.emitOpError("only support nested array result type"); - } auto valueElemType = shapedType.getElementType(); if (valueElemType != opElemType) { @@ -1513,26 +1509,24 @@ << numElements << ") does not match value number of elements (" << shapedType.getNumElements() << ")"; } - } break; - case StandardAttributes::Array: { + return success(); + } + if (auto attayAttr = value.dyn_cast()) { auto arrayType = opType.dyn_cast(); if (!arrayType) return constOp.emitOpError( "must have spv.array result type for array value"); - auto elemType = arrayType.getElementType(); - for (auto element : value.cast().getValue()) { + Type elemType = arrayType.getElementType(); + for (Attribute element : attayAttr.getValue()) { if (element.getType() != elemType) return constOp.emitOpError("has array element whose type (") << element.getType() << ") does not match the result element type (" << elemType << ')'; } - } break; - default: - return constOp.emitOpError("cannot have value of type ") << valueType; + return success(); } - - return success(); + return constOp.emitOpError("cannot have value of type ") << valueType; } bool spirv::ConstantOp::isBuildableWith(Type type) { @@ -2619,19 +2613,14 @@ return constOp.emitOpError("SpecId cannot be negative"); auto value = constOp.default_value(); - - switch (value.getKind()) { - case StandardAttributes::Integer: - case StandardAttributes::Float: { + if (value.isa()) { // Make sure bitwidth is allowed. if (!value.getType().isa()) return constOp.emitOpError("default value bitwidth disallowed"); return success(); } - default: - return constOp.emitOpError( - "default value can only be a bool, integer, or float scalar"); - } + return constOp.emitOpError( + "default value can only be a bool, integer, or float scalar"); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -33,6 +33,7 @@ #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSet.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Regex.h" #include "llvm/Support/SaveAndRestore.h" @@ -1019,76 +1020,67 @@ } void ModulePrinter::printLocationInternal(LocationAttr loc, bool pretty) { - switch (loc.getKind()) { - case StandardAttributes::OpaqueLocation: - printLocationInternal(loc.cast().getFallbackLocation(), pretty); - break; - case StandardAttributes::UnknownLocation: - if (pretty) - os << "[unknown]"; - else - os << "unknown"; - break; - case StandardAttributes::FileLineColLocation: { - auto fileLoc = loc.cast(); - auto mayQuote = pretty ? "" : "\""; - os << mayQuote << fileLoc.getFilename() << mayQuote << ':' - << fileLoc.getLine() << ':' << fileLoc.getColumn(); - break; - } - case StandardAttributes::NameLocation: { - auto nameLoc = loc.cast(); - os << '\"' << nameLoc.getName() << '\"'; - - // Print the child if it isn't unknown. - auto childLoc = nameLoc.getChildLoc(); - if (!childLoc.isa()) { - os << '('; - printLocationInternal(childLoc, pretty); - os << ')'; - } - break; - } - case StandardAttributes::CallSiteLocation: { - auto callLocation = loc.cast(); - auto caller = callLocation.getCaller(); - auto callee = callLocation.getCallee(); - if (!pretty) - os << "callsite("; - printLocationInternal(callee, pretty); - if (pretty) { - if (callee.isa()) { - if (caller.isa()) { - os << " at "; + TypeSwitch(loc) + .Case([&](OpaqueLoc loc) { + printLocationInternal(loc.getFallbackLocation(), pretty); + }) + .Case([&](UnknownLoc loc) { + if (pretty) + os << "[unknown]"; + else + os << "unknown"; + }) + .Case([&](FileLineColLoc loc) { + StringRef mayQuote = pretty ? "" : "\""; + os << mayQuote << loc.getFilename() << mayQuote << ':' << loc.getLine() + << ':' << loc.getColumn(); + }) + .Case([&](NameLoc loc) { + os << '\"' << loc.getName() << '\"'; + + // Print the child if it isn't unknown. + auto childLoc = loc.getChildLoc(); + if (!childLoc.isa()) { + os << '('; + printLocationInternal(childLoc, pretty); + os << ')'; + } + }) + .Case([&](CallSiteLoc loc) { + Location caller = loc.getCaller(); + Location callee = loc.getCallee(); + if (!pretty) + os << "callsite("; + printLocationInternal(callee, pretty); + if (pretty) { + if (callee.isa()) { + if (caller.isa()) { + os << " at "; + } else { + os << newLine << " at "; + } + } else { + os << newLine << " at "; + } } else { - os << newLine << " at "; + os << " at "; } - } else { - os << newLine << " at "; - } - } else { - os << " at "; - } - printLocationInternal(caller, pretty); - if (!pretty) - os << ")"; - break; - } - case StandardAttributes::FusedLocation: { - auto fusedLoc = loc.cast(); - if (!pretty) - os << "fused"; - if (auto metadata = fusedLoc.getMetadata()) - os << '<' << metadata << '>'; - os << '['; - interleave( - fusedLoc.getLocations(), - [&](Location loc) { printLocationInternal(loc, pretty); }, - [&]() { os << ", "; }); - os << ']'; - break; - } - } + printLocationInternal(caller, pretty); + if (!pretty) + os << ")"; + }) + .Case([&](FusedLoc loc) { + if (!pretty) + os << "fused"; + if (auto metadata = loc.getMetadata()) + os << '<' << metadata << '>'; + os << '['; + interleave( + loc.getLocations(), + [&](Location loc) { printLocationInternal(loc, pretty); }, + [&]() { os << ", "; }); + os << ']'; + }); } /// Print a floating point value in a way that the parser will be able to @@ -1305,27 +1297,19 @@ } auto attrType = attr.getType(); - switch (attr.getKind()) { - default: - return printDialectAttribute(attr); - - case StandardAttributes::Opaque: { - auto opaqueAttr = attr.cast(); + if (auto opaqueAttr = attr.dyn_cast()) { printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(), opaqueAttr.getAttrData()); - break; - } - case StandardAttributes::Unit: + } else if (attr.isa()) { os << "unit"; - break; - case StandardAttributes::Dictionary: + return; + } else if (auto dictAttr = attr.dyn_cast()) { os << '{'; - interleaveComma(attr.cast().getValue(), + interleaveComma(dictAttr.getValue(), [&](NamedAttribute attr) { printNamedAttribute(attr); }); os << '}'; - break; - case StandardAttributes::Integer: { - auto intAttr = attr.cast(); + + } else if (auto intAttr = attr.dyn_cast()) { if (attrType.isSignlessInteger(1)) { os << (intAttr.getValue().getBoolValue() ? "true" : "false"); @@ -1343,114 +1327,98 @@ // IntegerAttr elides the type if I64. if (typeElision == AttrTypeElision::May && attrType.isSignlessInteger(64)) return; - break; - } - case StandardAttributes::Float: { - auto floatAttr = attr.cast(); + + } else if (auto floatAttr = attr.dyn_cast()) { printFloatValue(floatAttr.getValue(), os); // FloatAttr elides the type if F64. if (typeElision == AttrTypeElision::May && attrType.isF64()) return; - break; - } - case StandardAttributes::String: + + } else if (auto strAttr = attr.dyn_cast()) { os << '"'; - printEscapedString(attr.cast().getValue(), os); + printEscapedString(strAttr.getValue(), os); os << '"'; - break; - case StandardAttributes::Array: + + } else if (auto arrayAttr = attr.dyn_cast()) { os << '['; - interleaveComma(attr.cast().getValue(), [&](Attribute attr) { + interleaveComma(arrayAttr.getValue(), [&](Attribute attr) { printAttribute(attr, AttrTypeElision::May); }); os << ']'; - break; - case StandardAttributes::AffineMap: + + } else if (auto affineMapAttr = attr.dyn_cast()) { os << "affine_map<"; - attr.cast().getValue().print(os); + affineMapAttr.getValue().print(os); os << '>'; // AffineMap always elides the type. return; - case StandardAttributes::IntegerSet: + + } else if (auto integerSetAttr = attr.dyn_cast()) { os << "affine_set<"; - attr.cast().getValue().print(os); + integerSetAttr.getValue().print(os); os << '>'; // IntegerSet always elides the type. return; - case StandardAttributes::Type: - printType(attr.cast().getValue()); - break; - case StandardAttributes::SymbolRef: { - auto refAttr = attr.dyn_cast(); + + } else if (auto typeAttr = attr.dyn_cast()) { + printType(typeAttr.getValue()); + + } else if (auto refAttr = attr.dyn_cast()) { printSymbolReference(refAttr.getRootReference(), os); for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) { os << "::"; printSymbolReference(nestedRef.getValue(), os); } - break; - } - case StandardAttributes::OpaqueElements: { - auto eltsAttr = attr.cast(); - if (printerFlags.shouldElideElementsAttr(eltsAttr)) { + + } else if (auto opaqueAttr = attr.dyn_cast()) { + if (printerFlags.shouldElideElementsAttr(opaqueAttr)) { printElidedElementsAttr(os); - break; + } else { + os << "opaque<\"" << opaqueAttr.getDialect()->getNamespace() << "\", "; + os << '"' << "0x" << llvm::toHex(opaqueAttr.getValue()) << "\">"; } - os << "opaque<\"" << eltsAttr.getDialect()->getNamespace() << "\", "; - os << '"' << "0x" << llvm::toHex(eltsAttr.getValue()) << "\">"; - break; - } - case StandardAttributes::DenseIntOrFPElements: { - auto eltsAttr = attr.cast(); - if (printerFlags.shouldElideElementsAttr(eltsAttr)) { + + } else if (auto intOrFpEltAttr = attr.dyn_cast()) { + if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) { printElidedElementsAttr(os); - break; + } else { + os << "dense<"; + printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true); + os << '>'; } - os << "dense<"; - printDenseIntOrFPElementsAttr(eltsAttr, /*allowHex=*/true); - os << '>'; - break; - } - case StandardAttributes::DenseStringElements: { - auto eltsAttr = attr.cast(); - if (printerFlags.shouldElideElementsAttr(eltsAttr)) { + + } else if (auto strEltAttr = attr.dyn_cast()) { + if (printerFlags.shouldElideElementsAttr(strEltAttr)) { printElidedElementsAttr(os); - break; + } else { + os << "dense<"; + printDenseStringElementsAttr(strEltAttr); + os << '>'; } - os << "dense<"; - printDenseStringElementsAttr(eltsAttr); - os << '>'; - break; - } - case StandardAttributes::SparseElements: { - auto elementsAttr = attr.cast(); - if (printerFlags.shouldElideElementsAttr(elementsAttr.getIndices()) || - printerFlags.shouldElideElementsAttr(elementsAttr.getValues())) { + + } else if (auto sparseEltAttr = attr.dyn_cast()) { + if (printerFlags.shouldElideElementsAttr(sparseEltAttr.getIndices()) || + printerFlags.shouldElideElementsAttr(sparseEltAttr.getValues())) { printElidedElementsAttr(os); - break; - } - os << "sparse<"; - DenseIntElementsAttr indices = elementsAttr.getIndices(); - if (indices.getNumElements() != 0) { - printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false); - os << ", "; - printDenseElementsAttr(elementsAttr.getValues(), /*allowHex=*/true); + } else { + os << "sparse<"; + DenseIntElementsAttr indices = sparseEltAttr.getIndices(); + if (indices.getNumElements() != 0) { + printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false); + os << ", "; + printDenseElementsAttr(sparseEltAttr.getValues(), /*allowHex=*/true); + } + os << '>'; } - os << '>'; - break; - } - // Location attributes. - case StandardAttributes::CallSiteLocation: - case StandardAttributes::FileLineColLocation: - case StandardAttributes::FusedLocation: - case StandardAttributes::NameLocation: - case StandardAttributes::OpaqueLocation: - case StandardAttributes::UnknownLocation: - printLocation(attr.cast()); - break; + } else if (auto locAttr = attr.dyn_cast()) { + printLocation(locAttr); + + } else { + return printDialectAttribute(attr); } // Don't print the type if we must elide it, or if it is a None type. diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -460,16 +460,11 @@ /// Return the value at the given index. If index does not refer to a valid /// element, then a null attribute is returned. Attribute ElementsAttr::getValue(ArrayRef index) const { - switch (getKind()) { - case StandardAttributes::DenseIntOrFPElements: - return cast().getValue(index); - case StandardAttributes::OpaqueElements: - return cast().getValue(index); - case StandardAttributes::SparseElements: - return cast().getValue(index); - default: - llvm_unreachable("unknown ElementsAttr kind"); - } + if (auto denseAttr = dyn_cast()) + return denseAttr.getValue(index); + if (auto opaqueAttr = dyn_cast()) + return opaqueAttr.getValue(index); + return cast().getValue(index); } /// Return if the given 'index' refers to a valid element in this attribute. @@ -491,23 +486,23 @@ ElementsAttr ElementsAttr::mapValues(Type newElementType, function_ref mapping) const { - switch (getKind()) { - case StandardAttributes::DenseIntOrFPElements: - return cast().mapValues(newElementType, mapping); - default: - llvm_unreachable("unsupported ElementsAttr subtype"); - } + if (auto intOrFpAttr = dyn_cast()) + return intOrFpAttr.mapValues(newElementType, mapping); + llvm_unreachable("unsupported ElementsAttr subtype"); } ElementsAttr ElementsAttr::mapValues(Type newElementType, function_ref mapping) const { - switch (getKind()) { - case StandardAttributes::DenseIntOrFPElements: - return cast().mapValues(newElementType, mapping); - default: - llvm_unreachable("unsupported ElementsAttr subtype"); - } + if (auto intOrFpAttr = dyn_cast()) + return intOrFpAttr.mapValues(newElementType, mapping); + llvm_unreachable("unsupported ElementsAttr subtype"); +} + +/// Method for support type inquiry through isa, cast and dyn_cast. +bool ElementsAttr::classof(Attribute attr) { + return attr.isa(); } /// Returns the 1 dimensional flattened row-major index from the given @@ -718,6 +713,11 @@ // DenseElementsAttr //===----------------------------------------------------------------------===// +/// Method for support type inquiry through isa, cast and dyn_cast. +bool DenseElementsAttr::classof(Attribute attr) { + return attr.isa(); +} + DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef values) { assert(hasSameElementsOrSplat(type, values)); diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp --- a/mlir/lib/IR/Diagnostics.cpp +++ b/mlir/lib/IR/Diagnostics.cpp @@ -366,43 +366,38 @@ /// Return a processable FileLineColLoc from the given location. static Optional getFileLineColLoc(Location loc) { - switch (loc->getKind()) { - case StandardAttributes::NameLocation: + if (auto nameLoc = loc.dyn_cast()) return getFileLineColLoc(loc.cast().getChildLoc()); - case StandardAttributes::FileLineColLocation: - return loc.cast(); - case StandardAttributes::CallSiteLocation: - // Process the callee of a callsite location. + if (auto fileLoc = loc.dyn_cast()) + return fileLoc; + if (auto callLoc = loc.dyn_cast()) return getFileLineColLoc(loc.cast().getCallee()); - case StandardAttributes::FusedLocation: + if (auto fusedLoc = loc.dyn_cast()) { for (auto subLoc : loc.cast().getLocations()) { if (auto callLoc = getFileLineColLoc(subLoc)) { return callLoc; } } return llvm::None; - default: - return llvm::None; } + return llvm::None; } /// Return a processable CallSiteLoc from the given location. static Optional getCallSiteLoc(Location loc) { - switch (loc->getKind()) { - case StandardAttributes::NameLocation: + if (auto nameLoc = loc.dyn_cast()) return getCallSiteLoc(loc.cast().getChildLoc()); - case StandardAttributes::CallSiteLocation: - return loc.cast(); - case StandardAttributes::FusedLocation: + if (auto callLoc = loc.dyn_cast()) + return callLoc; + if (auto fusedLoc = loc.dyn_cast()) { for (auto subLoc : loc.cast().getLocations()) { if (auto callLoc = getCallSiteLoc(subLoc)) { return callLoc; } } return llvm::None; - default: - return llvm::None; } + return llvm::None; } /// Given a diagnostic kind, returns the LLVM DiagKind. diff --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp --- a/mlir/lib/IR/Location.cpp +++ b/mlir/lib/IR/Location.cpp @@ -13,6 +13,16 @@ using namespace mlir; using namespace mlir::detail; +//===----------------------------------------------------------------------===// +// LocationAttr +//===----------------------------------------------------------------------===// + +/// Methods for support type inquiry through isa, cast, and dyn_cast. +bool LocationAttr::classof(Attribute attr) { + return attr.isa(); +} + //===----------------------------------------------------------------------===// // CallSiteLoc //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp --- a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp @@ -115,26 +115,19 @@ return existingIt->second; const llvm::DILocation *llvmLoc = nullptr; - switch (loc->getKind()) { - case StandardAttributes::CallSiteLocation: { - auto callLoc = loc.dyn_cast(); - + if (auto callLoc = loc.dyn_cast()) { // For callsites, the caller is fed as the inlinedAt for the callee. const auto *callerLoc = translateLoc(callLoc.getCaller(), scope, inlinedAt); llvmLoc = translateLoc(callLoc.getCallee(), scope, callerLoc); - break; - } - case StandardAttributes::FileLineColLocation: { - auto fileLoc = loc.dyn_cast(); + + } else if (auto fileLoc = loc.dyn_cast()) { auto *file = translateFile(fileLoc.getFilename()); auto *fileScope = builder.createLexicalBlockFile(scope, file); llvmLoc = llvm::DILocation::get(llvmCtx, fileLoc.getLine(), fileLoc.getColumn(), fileScope, const_cast(inlinedAt)); - break; - } - case StandardAttributes::FusedLocation: { - auto fusedLoc = loc.dyn_cast(); + + } else if (auto fusedLoc = loc.dyn_cast()) { ArrayRef locations = fusedLoc.getLocations(); // For fused locations, merge each of the nodes. @@ -143,18 +136,17 @@ llvmLoc = llvm::DILocation::getMergedLocation( llvmLoc, translateLoc(locIt, scope, inlinedAt)); } - break; - } - case StandardAttributes::NameLocation: + + } else if (auto nameLoc = loc.dyn_cast()) { llvmLoc = translateLoc(loc.cast().getChildLoc(), scope, inlinedAt); - break; - case StandardAttributes::OpaqueLocation: + + } else if (auto opaqueLoc = loc.dyn_cast()) { llvmLoc = translateLoc(loc.cast().getFallbackLocation(), scope, inlinedAt); - break; - default: + } else { llvm_unreachable("unknown location kind"); } + locationToLoc.try_emplace(std::make_pair(loc, scope), llvmLoc); return llvmLoc; } diff --git a/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp b/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp --- a/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp +++ b/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp @@ -158,7 +158,7 @@ auto expectedTensorType = realValue.getType().cast(); EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape()); EXPECT_EQ(tensorType.getElementType(), convertedType); - EXPECT_EQ(returnedValue.getKind(), StandardAttributes::SparseElements); + EXPECT_TRUE(returnedValue.isa()); // Check Elements attribute element value is expected. auto firstValue = returnedValue.cast().getValue({0, 0});