diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -403,7 +403,7 @@ if (Operation *op = val.getDefiningOp()) { setInsertionPointAfter(op); } else { - auto blockArg = val.cast(); + auto blockArg = llvm::cast(val); setInsertionPointToStart(blockArg.getOwner()); } } diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -389,7 +389,7 @@ !std::is_same::value, T> getSplatValue() const { - return getSplatValue().template cast(); + return llvm::cast(getSplatValue()); } /// Try to get an iterator of the given type to the start of the held element @@ -510,7 +510,7 @@ T>::mapped_iterator_base; /// Map the element to the iterator result type. - T mapElement(Attribute attr) const { return attr.cast(); } + T mapElement(Attribute attr) const { return llvm::cast(attr); } }; template > FailureOr>> @@ -684,7 +684,7 @@ /// Method for support type inquiry through isa, cast and dyn_cast. static bool classof(Attribute attr) { - auto denseAttr = attr.dyn_cast(); + auto denseAttr = llvm::dyn_cast(attr); return denseAttr && denseAttr.isSplat(); } }; @@ -887,7 +887,7 @@ /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool classof(Attribute attr) { - SymbolRefAttr refAttr = attr.dyn_cast(); + SymbolRefAttr refAttr = llvm::dyn_cast(attr); return refAttr && refAttr.getNestedReferences().empty(); } @@ -912,14 +912,13 @@ /// simply wraps the DenseElementsAttr::get calls. template static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg) { - return DenseElementsAttr::get(type, llvm::ArrayRef(arg)) - .template cast(); + return llvm::cast( + DenseElementsAttr::get(type, llvm::ArrayRef(arg))); } template static DenseFPElementsAttr get(const ShapedType &type, const std::initializer_list &list) { - return DenseElementsAttr::get(type, list) - .template cast(); + return llvm::cast(DenseElementsAttr::get(type, list)); } /// Generates a new DenseElementsAttr by mapping each value attribute, and @@ -954,14 +953,13 @@ /// simply wraps the DenseElementsAttr::get calls. template static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg) { - return DenseElementsAttr::get(type, llvm::ArrayRef(arg)) - .template cast(); + return llvm::cast( + DenseElementsAttr::get(type, llvm::ArrayRef(arg))); } template static DenseIntElementsAttr get(const ShapedType &type, const std::initializer_list &list) { - return DenseElementsAttr::get(type, list) - .template cast(); + return llvm::cast(DenseElementsAttr::get(type, list)); } /// Generates a new DenseElementsAttr by mapping each value attribute, and diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -367,20 +367,21 @@ //===----------------------------------------------------------------------===// inline bool BaseMemRefType::classof(Type type) { - return type.isa(); + return llvm::isa(type); } inline bool BaseMemRefType::isValidElementType(Type type) { return type.isIntOrIndexOrFloat() || - type.isa() || - type.isa(); + llvm::isa( + type) || + llvm::isa(type); } inline bool FloatType::classof(Type type) { - return type - .isa(); + return llvm::isa(type); } inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) { @@ -428,7 +429,7 @@ } inline bool TensorType::classof(Type type) { - return type.isa(); + return llvm::isa(type); } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/FunctionInterfaces.h b/mlir/include/mlir/IR/FunctionInterfaces.h --- a/mlir/include/mlir/IR/FunctionInterfaces.h +++ b/mlir/include/mlir/IR/FunctionInterfaces.h @@ -178,7 +178,7 @@ } for (unsigned i = 0; i != numArgs; ++i) { DictionaryAttr argAttrs = - allArgAttrs[i].dyn_cast_or_null(); + llvm::dyn_cast_or_null(allArgAttrs[i]); if (!argAttrs) { return op.emitOpError() << "expects argument attribute dictionary " "to be a DictionaryAttr, but got `" @@ -209,7 +209,7 @@ } for (unsigned i = 0; i != numResults; ++i) { DictionaryAttr resultAttrs = - allResultAttrs[i].dyn_cast_or_null(); + llvm::dyn_cast_or_null(allResultAttrs[i]); if (!resultAttrs) { return op.emitOpError() << "expects result attribute dictionary " "to be a DictionaryAttr, but got `" diff --git a/mlir/include/mlir/IR/Location.h b/mlir/include/mlir/IR/Location.h --- a/mlir/include/mlir/IR/Location.h +++ b/mlir/include/mlir/IR/Location.h @@ -148,12 +148,12 @@ /// Return the metadata associated with this fused location. MetadataT getMetadata() const { - return FusedLoc::getMetadata().template cast(); + return llvm::cast(FusedLoc::getMetadata()); } /// Support llvm style casting. static bool classof(Attribute attr) { - auto fusedLoc = attr.dyn_cast(); + auto fusedLoc = llvm::dyn_cast(attr); return fusedLoc && fusedLoc.getMetadata().isa_and_nonnull(); } }; diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h --- a/mlir/include/mlir/IR/Matchers.h +++ b/mlir/include/mlir/IR/Matchers.h @@ -39,7 +39,7 @@ attr_value_binder(ValueType *bv) : bind_value(bv) {} bool match(const Attribute &attr) { - if (auto intAttr = attr.dyn_cast()) { + if (auto intAttr = llvm::dyn_cast(attr)) { *bind_value = intAttr.getValue(); return true; } @@ -90,7 +90,7 @@ (void)result; assert(succeeded(result) && "expected ConstantLike op to be foldable"); - if (auto attr = foldedOp.front().get().dyn_cast()) { + if (auto attr = llvm::dyn_cast(foldedOp.front().get())) { if (bind_value) *bind_value = attr; return true; @@ -136,10 +136,10 @@ return false; auto type = op->getResult(0).getType(); - if (type.isa()) + if (llvm::isa(type)) return attr_value_binder(bind_value).match(attr); - if (type.isa()) { - if (auto splatAttr = attr.dyn_cast()) { + if (llvm::isa(type)) { + if (auto splatAttr = llvm::dyn_cast(attr)) { return attr_value_binder(bind_value) .match(splatAttr.getSplatValue()); } @@ -173,10 +173,10 @@ return false; auto type = op->getResult(0).getType(); - if (type.isa()) + if (llvm::isa(type)) return attr_value_binder(bind_value).match(attr); - if (type.isa()) { - if (auto splatAttr = attr.dyn_cast()) { + if (llvm::isa(type)) { + if (auto splatAttr = llvm::dyn_cast(attr)) { return attr_value_binder(bind_value) .match(splatAttr.getSplatValue()); } diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -204,7 +204,7 @@ auto &os = getStream() << " -> "; bool wrapped = !llvm::hasSingleElement(types) || - (*types.begin()).template isa(); + llvm::isa((*types.begin())); if (wrapped) os << '('; llvm::interleaveComma(types, *this); @@ -865,7 +865,7 @@ return failure(); // Check for the right kind of attribute. - if (!(result = attr.dyn_cast())) + if (!(result = llvm::dyn_cast(attr))) return emitError(loc, "invalid kind of attribute specified"); return success(); @@ -899,7 +899,7 @@ return failure(); // Check for the right kind of attribute. - result = attr.dyn_cast(); + result = llvm::dyn_cast(attr); if (!result) return emitError(loc, "invalid kind of attribute specified"); @@ -936,7 +936,7 @@ return failure(); // Check for the right kind of attribute. - result = attr.dyn_cast(); + result = llvm::dyn_cast(attr); if (!result) return emitError(loc, "invalid kind of attribute specified"); @@ -970,7 +970,7 @@ return failure(); // Check for the right kind of attribute. - result = attr.dyn_cast(); + result = llvm::dyn_cast(attr); if (!result) return emitError(loc, "invalid kind of attribute specified"); return success(); @@ -1126,7 +1126,7 @@ return failure(); // Check for the right kind of type. - result = type.dyn_cast(); + result = llvm::dyn_cast(type); if (!result) return emitError(loc, "invalid kind of type specified"); @@ -1158,7 +1158,7 @@ return failure(); // Check for the right kind of Type. - result = type.dyn_cast(); + result = llvm::dyn_cast(type); if (!result) return emitError(loc, "invalid kind of Type specified"); return success(); @@ -1198,7 +1198,7 @@ return failure(); // Check for the right kind of type. - result = type.dyn_cast(); + result = llvm::dyn_cast(type); if (!result) return emitError(loc, "invalid kind of type specified"); diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -509,11 +509,11 @@ template AttrClass getAttrOfType(StringAttr name) { - return getAttr(name).dyn_cast_or_null(); + return llvm::dyn_cast_or_null(getAttr(name)); } template AttrClass getAttrOfType(StringRef name) { - return getAttr(name).dyn_cast_or_null(); + return llvm::dyn_cast_or_null(getAttr(name)); } /// Return true if the operation has an attribute with the provided name, diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -433,7 +433,7 @@ static bool classof(Value value) { return llvm::isa(value.getType()); } /// Return the known Type - Ty getType() { return Value::getType().template cast(); } + Ty getType() { return llvm::cast(Value::getType()); } void setType(Ty ty) { Value::setType(ty); } }; diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -25,7 +25,7 @@ //===----------------------------------------------------------------------===// bool mlirAttributeIsALocation(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } //===----------------------------------------------------------------------===// @@ -33,7 +33,7 @@ //===----------------------------------------------------------------------===// bool mlirAttributeIsAAffineMap(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map) { @@ -41,7 +41,7 @@ } MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getValue()); + return wrap(llvm::cast(unwrap(attr)).getValue()); } //===----------------------------------------------------------------------===// @@ -49,7 +49,7 @@ //===----------------------------------------------------------------------===// bool mlirAttributeIsAArray(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements, @@ -61,11 +61,11 @@ } intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr) { - return static_cast(unwrap(attr).cast().size()); + return static_cast(llvm::cast(unwrap(attr)).size()); } MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos) { - return wrap(unwrap(attr).cast().getValue()[pos]); + return wrap(llvm::cast(unwrap(attr)).getValue()[pos]); } //===----------------------------------------------------------------------===// @@ -73,7 +73,7 @@ //===----------------------------------------------------------------------===// bool mlirAttributeIsADictionary(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements, @@ -87,19 +87,19 @@ } intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr) { - return static_cast(unwrap(attr).cast().size()); + return static_cast(llvm::cast(unwrap(attr)).size()); } MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr, intptr_t pos) { NamedAttribute attribute = - unwrap(attr).cast().getValue()[pos]; + llvm::cast(unwrap(attr)).getValue()[pos]; return {wrap(attribute.getName()), wrap(attribute.getValue())}; } MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr, MlirStringRef name) { - return wrap(unwrap(attr).cast().get(unwrap(name))); + return wrap(llvm::cast(unwrap(attr)).get(unwrap(name))); } //===----------------------------------------------------------------------===// @@ -107,7 +107,7 @@ //===----------------------------------------------------------------------===// bool mlirAttributeIsAFloat(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type, @@ -121,7 +121,7 @@ } double mlirFloatAttrGetValueDouble(MlirAttribute attr) { - return unwrap(attr).cast().getValueAsDouble(); + return llvm::cast(unwrap(attr)).getValueAsDouble(); } //===----------------------------------------------------------------------===// @@ -129,7 +129,7 @@ //===----------------------------------------------------------------------===// bool mlirAttributeIsAInteger(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value) { @@ -137,15 +137,15 @@ } int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr) { - return unwrap(attr).cast().getInt(); + return llvm::cast(unwrap(attr)).getInt(); } int64_t mlirIntegerAttrGetValueSInt(MlirAttribute attr) { - return unwrap(attr).cast().getSInt(); + return llvm::cast(unwrap(attr)).getSInt(); } uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr) { - return unwrap(attr).cast().getUInt(); + return llvm::cast(unwrap(attr)).getUInt(); } //===----------------------------------------------------------------------===// @@ -153,7 +153,7 @@ //===----------------------------------------------------------------------===// bool mlirAttributeIsABool(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) { @@ -161,7 +161,7 @@ } bool mlirBoolAttrGetValue(MlirAttribute attr) { - return unwrap(attr).cast().getValue(); + return llvm::cast(unwrap(attr)).getValue(); } //===----------------------------------------------------------------------===// @@ -169,7 +169,7 @@ //===----------------------------------------------------------------------===// bool mlirAttributeIsAIntegerSet(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } //===----------------------------------------------------------------------===// @@ -177,7 +177,7 @@ //===----------------------------------------------------------------------===// bool mlirAttributeIsAOpaque(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace, @@ -189,11 +189,12 @@ } MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getDialectNamespace().strref()); + return wrap( + llvm::cast(unwrap(attr)).getDialectNamespace().strref()); } MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getAttrData()); + return wrap(llvm::cast(unwrap(attr)).getAttrData()); } //===----------------------------------------------------------------------===// @@ -201,7 +202,7 @@ //===----------------------------------------------------------------------===// bool mlirAttributeIsAString(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) { @@ -213,7 +214,7 @@ } MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getValue()); + return wrap(llvm::cast(unwrap(attr)).getValue()); } //===----------------------------------------------------------------------===// @@ -221,7 +222,7 @@ //===----------------------------------------------------------------------===// bool mlirAttributeIsASymbolRef(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol, @@ -230,27 +231,30 @@ SmallVector refs; refs.reserve(numReferences); for (intptr_t i = 0; i < numReferences; ++i) - refs.push_back(unwrap(references[i]).cast()); + refs.push_back(llvm::cast(unwrap(references[i]))); auto symbolAttr = StringAttr::get(unwrap(ctx), unwrap(symbol)); return wrap(SymbolRefAttr::get(symbolAttr, refs)); } MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getRootReference().getValue()); + return wrap( + llvm::cast(unwrap(attr)).getRootReference().getValue()); } MlirStringRef mlirSymbolRefAttrGetLeafReference(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getLeafReference().getValue()); + return wrap( + llvm::cast(unwrap(attr)).getLeafReference().getValue()); } intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr) { return static_cast( - unwrap(attr).cast().getNestedReferences().size()); + llvm::cast(unwrap(attr)).getNestedReferences().size()); } MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr, intptr_t pos) { - return wrap(unwrap(attr).cast().getNestedReferences()[pos]); + return wrap( + llvm::cast(unwrap(attr)).getNestedReferences()[pos]); } //===----------------------------------------------------------------------===// @@ -258,7 +262,7 @@ //===----------------------------------------------------------------------===// bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol) { @@ -266,7 +270,7 @@ } MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getValue()); + return wrap(llvm::cast(unwrap(attr)).getValue()); } //===----------------------------------------------------------------------===// @@ -274,7 +278,7 @@ //===----------------------------------------------------------------------===// bool mlirAttributeIsAType(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirTypeAttrGet(MlirType type) { @@ -282,7 +286,7 @@ } MlirType mlirTypeAttrGetValue(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getValue()); + return wrap(llvm::cast(unwrap(attr)).getValue()); } //===----------------------------------------------------------------------===// @@ -290,7 +294,7 @@ //===----------------------------------------------------------------------===// bool mlirAttributeIsAUnit(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirUnitAttrGet(MlirContext ctx) { @@ -302,24 +306,23 @@ //===----------------------------------------------------------------------===// bool mlirAttributeIsAElements(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank, uint64_t *idxs) { - return wrap(unwrap(attr) - .cast() + return wrap(llvm::cast(unwrap(attr)) .getValues()[llvm::ArrayRef(idxs, rank)]); } bool mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank, uint64_t *idxs) { - return unwrap(attr).cast().isValidIndex( - llvm::ArrayRef(idxs, rank)); + return llvm::cast(unwrap(attr)) + .isValidIndex(llvm::ArrayRef(idxs, rank)); } int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) { - return unwrap(attr).cast().getNumElements(); + return llvm::cast(unwrap(attr)).getNumElements(); } //===----------------------------------------------------------------------===// @@ -330,25 +333,25 @@ // IsA support. bool mlirAttributeIsADenseBoolArray(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } bool mlirAttributeIsADenseI8Array(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } bool mlirAttributeIsADenseI16Array(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } bool mlirAttributeIsADenseI32Array(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } bool mlirAttributeIsADenseI64Array(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } bool mlirAttributeIsADenseF32Array(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } bool mlirAttributeIsADenseF64Array(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } //===----------------------------------------------------------------------===// @@ -394,32 +397,32 @@ // Accessors. intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr) { - return unwrap(attr).cast().size(); + return llvm::cast(unwrap(attr)).size(); } //===----------------------------------------------------------------------===// // Indexed accessors. bool mlirDenseBoolArrayGetElement(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast()[pos]; + return llvm::cast(unwrap(attr))[pos]; } int8_t mlirDenseI8ArrayGetElement(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast()[pos]; + return llvm::cast(unwrap(attr))[pos]; } int16_t mlirDenseI16ArrayGetElement(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast()[pos]; + return llvm::cast(unwrap(attr))[pos]; } int32_t mlirDenseI32ArrayGetElement(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast()[pos]; + return llvm::cast(unwrap(attr))[pos]; } int64_t mlirDenseI64ArrayGetElement(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast()[pos]; + return llvm::cast(unwrap(attr))[pos]; } float mlirDenseF32ArrayGetElement(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast()[pos]; + return llvm::cast(unwrap(attr))[pos]; } double mlirDenseF64ArrayGetElement(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast()[pos]; + return llvm::cast(unwrap(attr))[pos]; } //===----------------------------------------------------------------------===// @@ -430,13 +433,13 @@ // IsA support. bool mlirAttributeIsADenseElements(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } bool mlirAttributeIsADenseIntElements(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } bool mlirAttributeIsADenseFPElements(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } //===----------------------------------------------------------------------===// @@ -447,14 +450,14 @@ MlirAttribute const *elements) { SmallVector attributes; return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), + DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), unwrapList(numElements, elements, attributes))); } MlirAttribute mlirDenseElementsAttrRawBufferGet(MlirType shapedType, size_t rawBufferSize, const void *rawBuffer) { - auto shapedTypeCpp = unwrap(shapedType).cast(); + auto shapedTypeCpp = llvm::cast(unwrap(shapedType)); ArrayRef rawBufferCpp(static_cast(rawBuffer), rawBufferSize); bool isSplat = false; @@ -466,61 +469,61 @@ MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType, MlirAttribute element) { - return wrap(DenseElementsAttr::get(unwrap(shapedType).cast(), + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), unwrap(element))); } MlirAttribute mlirDenseElementsAttrBoolSplatGet(MlirType shapedType, bool element) { - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), element)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); } MlirAttribute mlirDenseElementsAttrUInt8SplatGet(MlirType shapedType, uint8_t element) { - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), element)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); } MlirAttribute mlirDenseElementsAttrInt8SplatGet(MlirType shapedType, int8_t element) { - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), element)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); } MlirAttribute mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType, uint32_t element) { - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), element)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); } MlirAttribute mlirDenseElementsAttrInt32SplatGet(MlirType shapedType, int32_t element) { - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), element)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); } MlirAttribute mlirDenseElementsAttrUInt64SplatGet(MlirType shapedType, uint64_t element) { - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), element)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); } MlirAttribute mlirDenseElementsAttrInt64SplatGet(MlirType shapedType, int64_t element) { - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), element)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); } MlirAttribute mlirDenseElementsAttrFloatSplatGet(MlirType shapedType, float element) { - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), element)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); } MlirAttribute mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType, double element) { - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), element)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); } MlirAttribute mlirDenseElementsAttrBoolGet(MlirType shapedType, intptr_t numElements, const int *elements) { SmallVector values(elements, elements + numElements); - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), values)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + values)); } /// Creates a dense attribute with elements of the type deduced by templates. @@ -528,7 +531,7 @@ static MlirAttribute getDenseAttribute(MlirType shapedType, intptr_t numElements, const T *elements) { - return wrap(DenseElementsAttr::get(unwrap(shapedType).cast(), + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), llvm::ArrayRef(elements, numElements))); } @@ -605,99 +608,99 @@ for (intptr_t i = 0; i < numElements; ++i) values.push_back(unwrap(strs[i])); - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), values)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + values)); } MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr, MlirType shapedType) { - return wrap(unwrap(attr).cast().reshape( - unwrap(shapedType).cast())); + return wrap(llvm::cast(unwrap(attr)) + .reshape(llvm::cast(unwrap(shapedType)))); } //===----------------------------------------------------------------------===// // Splat accessors. bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) { - return unwrap(attr).cast().isSplat(); + return llvm::cast(unwrap(attr)).isSplat(); } MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) { return wrap( - unwrap(attr).cast().getSplatValue()); + llvm::cast(unwrap(attr)).getSplatValue()); } int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) { - return unwrap(attr).cast().getSplatValue(); + return llvm::cast(unwrap(attr)).getSplatValue(); } int8_t mlirDenseElementsAttrGetInt8SplatValue(MlirAttribute attr) { - return unwrap(attr).cast().getSplatValue(); + return llvm::cast(unwrap(attr)).getSplatValue(); } uint8_t mlirDenseElementsAttrGetUInt8SplatValue(MlirAttribute attr) { - return unwrap(attr).cast().getSplatValue(); + return llvm::cast(unwrap(attr)).getSplatValue(); } int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr) { - return unwrap(attr).cast().getSplatValue(); + return llvm::cast(unwrap(attr)).getSplatValue(); } uint32_t mlirDenseElementsAttrGetUInt32SplatValue(MlirAttribute attr) { - return unwrap(attr).cast().getSplatValue(); + return llvm::cast(unwrap(attr)).getSplatValue(); } int64_t mlirDenseElementsAttrGetInt64SplatValue(MlirAttribute attr) { - return unwrap(attr).cast().getSplatValue(); + return llvm::cast(unwrap(attr)).getSplatValue(); } uint64_t mlirDenseElementsAttrGetUInt64SplatValue(MlirAttribute attr) { - return unwrap(attr).cast().getSplatValue(); + return llvm::cast(unwrap(attr)).getSplatValue(); } float mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr) { - return unwrap(attr).cast().getSplatValue(); + return llvm::cast(unwrap(attr)).getSplatValue(); } double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr) { - return unwrap(attr).cast().getSplatValue(); + return llvm::cast(unwrap(attr)).getSplatValue(); } MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) { return wrap( - unwrap(attr).cast().getSplatValue()); + llvm::cast(unwrap(attr)).getSplatValue()); } //===----------------------------------------------------------------------===// // Indexed accessors. bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getValues()[pos]; + return llvm::cast(unwrap(attr)).getValues()[pos]; } int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getValues()[pos]; + return llvm::cast(unwrap(attr)).getValues()[pos]; } uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getValues()[pos]; + return llvm::cast(unwrap(attr)).getValues()[pos]; } int16_t mlirDenseElementsAttrGetInt16Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getValues()[pos]; + return llvm::cast(unwrap(attr)).getValues()[pos]; } uint16_t mlirDenseElementsAttrGetUInt16Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getValues()[pos]; + return llvm::cast(unwrap(attr)).getValues()[pos]; } int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getValues()[pos]; + return llvm::cast(unwrap(attr)).getValues()[pos]; } uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getValues()[pos]; + return llvm::cast(unwrap(attr)).getValues()[pos]; } int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getValues()[pos]; + return llvm::cast(unwrap(attr)).getValues()[pos]; } uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getValues()[pos]; + return llvm::cast(unwrap(attr)).getValues()[pos]; } float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getValues()[pos]; + return llvm::cast(unwrap(attr)).getValues()[pos]; } double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getValues()[pos]; + return llvm::cast(unwrap(attr)).getValues()[pos]; } MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr, intptr_t pos) { return wrap( - unwrap(attr).cast().getValues()[pos]); + llvm::cast(unwrap(attr)).getValues()[pos]); } //===----------------------------------------------------------------------===// @@ -705,7 +708,7 @@ const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) { return static_cast( - unwrap(attr).cast().getRawData().data()); + llvm::cast(unwrap(attr)).getRawData().data()); } //===----------------------------------------------------------------------===// @@ -715,7 +718,7 @@ template static MlirAttribute getDenseResource(MlirType shapedType, MlirStringRef name, intptr_t numElements, const T *elements) { - return wrap(U::get(unwrap(shapedType).cast(), unwrap(name), + return wrap(U::get(llvm::cast(unwrap(shapedType)), unwrap(name), UnmanagedAsmResourceBlob::allocateInferAlign( llvm::ArrayRef(elements, numElements)))); } @@ -797,7 +800,7 @@ template static T getDenseResourceVal(MlirAttribute attr, intptr_t pos) { - return (*unwrap(attr).cast().tryGetAsArrayRef())[pos]; + return (*llvm::cast(unwrap(attr)).tryGetAsArrayRef())[pos]; } MLIR_CAPI_EXPORTED bool @@ -853,24 +856,24 @@ //===----------------------------------------------------------------------===// bool mlirAttributeIsASparseElements(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirSparseElementsAttribute(MlirType shapedType, MlirAttribute denseIndices, MlirAttribute denseValues) { - return wrap( - SparseElementsAttr::get(unwrap(shapedType).cast(), - unwrap(denseIndices).cast(), - unwrap(denseValues).cast())); + return wrap(SparseElementsAttr::get( + llvm::cast(unwrap(shapedType)), + llvm::cast(unwrap(denseIndices)), + llvm::cast(unwrap(denseValues)))); } MlirAttribute mlirSparseElementsAttrGetIndices(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getIndices()); + return wrap(llvm::cast(unwrap(attr)).getIndices()); } MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getValues()); + return wrap(llvm::cast(unwrap(attr)).getValues()); } //===----------------------------------------------------------------------===// @@ -878,7 +881,7 @@ //===----------------------------------------------------------------------===// bool mlirAttributeIsAStridedLayout(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset, @@ -889,14 +892,14 @@ } int64_t mlirStridedLayoutAttrGetOffset(MlirAttribute attr) { - return unwrap(attr).cast().getOffset(); + return llvm::cast(unwrap(attr)).getOffset(); } intptr_t mlirStridedLayoutAttrGetNumStrides(MlirAttribute attr) { return static_cast( - unwrap(attr).cast().getStrides().size()); + llvm::cast(unwrap(attr)).getStrides().size()); } int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getStrides()[pos]; + return llvm::cast(unwrap(attr)).getStrides()[pos]; } diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -23,7 +23,7 @@ //===----------------------------------------------------------------------===// bool mlirTypeIsAInteger(MlirType type) { - return unwrap(type).isa(); + return llvm::isa(unwrap(type)); } MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth) { @@ -39,26 +39,28 @@ } unsigned mlirIntegerTypeGetWidth(MlirType type) { - return unwrap(type).cast().getWidth(); + return llvm::cast(unwrap(type)).getWidth(); } bool mlirIntegerTypeIsSignless(MlirType type) { - return unwrap(type).cast().isSignless(); + return llvm::cast(unwrap(type)).isSignless(); } bool mlirIntegerTypeIsSigned(MlirType type) { - return unwrap(type).cast().isSigned(); + return llvm::cast(unwrap(type)).isSigned(); } bool mlirIntegerTypeIsUnsigned(MlirType type) { - return unwrap(type).cast().isUnsigned(); + return llvm::cast(unwrap(type)).isUnsigned(); } //===----------------------------------------------------------------------===// // Index type. //===----------------------------------------------------------------------===// -bool mlirTypeIsAIndex(MlirType type) { return unwrap(type).isa(); } +bool mlirTypeIsAIndex(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirIndexTypeGet(MlirContext ctx) { return wrap(IndexType::get(unwrap(ctx))); @@ -136,7 +138,9 @@ // None type. //===----------------------------------------------------------------------===// -bool mlirTypeIsANone(MlirType type) { return unwrap(type).isa(); } +bool mlirTypeIsANone(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirNoneTypeGet(MlirContext ctx) { return wrap(NoneType::get(unwrap(ctx))); @@ -147,7 +151,7 @@ //===----------------------------------------------------------------------===// bool mlirTypeIsAComplex(MlirType type) { - return unwrap(type).isa(); + return llvm::isa(unwrap(type)); } MlirType mlirComplexTypeGet(MlirType elementType) { @@ -155,38 +159,41 @@ } MlirType mlirComplexTypeGetElementType(MlirType type) { - return wrap(unwrap(type).cast().getElementType()); + return wrap(llvm::cast(unwrap(type)).getElementType()); } //===----------------------------------------------------------------------===// // Shaped type. //===----------------------------------------------------------------------===// -bool mlirTypeIsAShaped(MlirType type) { return unwrap(type).isa(); } +bool mlirTypeIsAShaped(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirShapedTypeGetElementType(MlirType type) { - return wrap(unwrap(type).cast().getElementType()); + return wrap(llvm::cast(unwrap(type)).getElementType()); } bool mlirShapedTypeHasRank(MlirType type) { - return unwrap(type).cast().hasRank(); + return llvm::cast(unwrap(type)).hasRank(); } int64_t mlirShapedTypeGetRank(MlirType type) { - return unwrap(type).cast().getRank(); + return llvm::cast(unwrap(type)).getRank(); } bool mlirShapedTypeHasStaticShape(MlirType type) { - return unwrap(type).cast().hasStaticShape(); + return llvm::cast(unwrap(type)).hasStaticShape(); } bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim) { - return unwrap(type).cast().isDynamicDim( - static_cast(dim)); + return llvm::cast(unwrap(type)) + .isDynamicDim(static_cast(dim)); } int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim) { - return unwrap(type).cast().getDimSize(static_cast(dim)); + return llvm::cast(unwrap(type)) + .getDimSize(static_cast(dim)); } int64_t mlirShapedTypeGetDynamicSize() { return ShapedType::kDynamic; } @@ -207,7 +214,9 @@ // Vector type. //===----------------------------------------------------------------------===// -bool mlirTypeIsAVector(MlirType type) { return unwrap(type).isa(); } +bool mlirTypeIsAVector(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape, MlirType elementType) { @@ -226,14 +235,16 @@ // Ranked / Unranked tensor type. //===----------------------------------------------------------------------===// -bool mlirTypeIsATensor(MlirType type) { return unwrap(type).isa(); } +bool mlirTypeIsATensor(MlirType type) { + return llvm::isa(unwrap(type)); +} bool mlirTypeIsARankedTensor(MlirType type) { - return unwrap(type).isa(); + return llvm::isa(unwrap(type)); } bool mlirTypeIsAUnrankedTensor(MlirType type) { - return unwrap(type).isa(); + return llvm::isa(unwrap(type)); } MlirType mlirRankedTensorTypeGet(intptr_t rank, const int64_t *shape, @@ -253,7 +264,7 @@ } MlirAttribute mlirRankedTensorTypeGetEncoding(MlirType type) { - return wrap(unwrap(type).cast().getEncoding()); + return wrap(llvm::cast(unwrap(type)).getEncoding()); } MlirType mlirUnrankedTensorTypeGet(MlirType elementType) { @@ -269,7 +280,9 @@ // Ranked / Unranked MemRef type. //===----------------------------------------------------------------------===// -bool mlirTypeIsAMemRef(MlirType type) { return unwrap(type).isa(); } +bool mlirTypeIsAMemRef(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank, const int64_t *shape, MlirAttribute layout, @@ -278,7 +291,7 @@ llvm::ArrayRef(shape, static_cast(rank)), unwrap(elementType), mlirAttributeIsNull(layout) ? MemRefLayoutAttrInterface() - : unwrap(layout).cast(), + : llvm::cast(unwrap(layout)), unwrap(memorySpace))); } @@ -291,7 +304,7 @@ unwrap(elementType), mlirAttributeIsNull(layout) ? MemRefLayoutAttrInterface() - : unwrap(layout).cast(), + : llvm::cast(unwrap(layout)), unwrap(memorySpace))); } @@ -313,19 +326,19 @@ } MlirAttribute mlirMemRefTypeGetLayout(MlirType type) { - return wrap(unwrap(type).cast().getLayout()); + return wrap(llvm::cast(unwrap(type)).getLayout()); } MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type) { - return wrap(unwrap(type).cast().getLayout().getAffineMap()); + return wrap(llvm::cast(unwrap(type)).getLayout().getAffineMap()); } MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) { - return wrap(unwrap(type).cast().getMemorySpace()); + return wrap(llvm::cast(unwrap(type)).getMemorySpace()); } bool mlirTypeIsAUnrankedMemRef(MlirType type) { - return unwrap(type).isa(); + return llvm::isa(unwrap(type)); } MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, @@ -342,14 +355,16 @@ } MlirAttribute mlirUnrankedMemrefGetMemorySpace(MlirType type) { - return wrap(unwrap(type).cast().getMemorySpace()); + return wrap(llvm::cast(unwrap(type)).getMemorySpace()); } //===----------------------------------------------------------------------===// // Tuple type. //===----------------------------------------------------------------------===// -bool mlirTypeIsATuple(MlirType type) { return unwrap(type).isa(); } +bool mlirTypeIsATuple(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements, MlirType const *elements) { @@ -359,11 +374,12 @@ } intptr_t mlirTupleTypeGetNumTypes(MlirType type) { - return unwrap(type).cast().size(); + return llvm::cast(unwrap(type)).size(); } MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos) { - return wrap(unwrap(type).cast().getType(static_cast(pos))); + return wrap( + llvm::cast(unwrap(type)).getType(static_cast(pos))); } //===----------------------------------------------------------------------===// @@ -371,7 +387,7 @@ //===----------------------------------------------------------------------===// bool mlirTypeIsAFunction(MlirType type) { - return unwrap(type).isa(); + return llvm::isa(unwrap(type)); } MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs, @@ -385,30 +401,32 @@ } intptr_t mlirFunctionTypeGetNumInputs(MlirType type) { - return unwrap(type).cast().getNumInputs(); + return llvm::cast(unwrap(type)).getNumInputs(); } intptr_t mlirFunctionTypeGetNumResults(MlirType type) { - return unwrap(type).cast().getNumResults(); + return llvm::cast(unwrap(type)).getNumResults(); } MlirType mlirFunctionTypeGetInput(MlirType type, intptr_t pos) { assert(pos >= 0 && "pos in array must be positive"); - return wrap( - unwrap(type).cast().getInput(static_cast(pos))); + return wrap(llvm::cast(unwrap(type)) + .getInput(static_cast(pos))); } MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos) { assert(pos >= 0 && "pos in array must be positive"); - return wrap( - unwrap(type).cast().getResult(static_cast(pos))); + return wrap(llvm::cast(unwrap(type)) + .getResult(static_cast(pos))); } //===----------------------------------------------------------------------===// // Opaque type. //===----------------------------------------------------------------------===// -bool mlirTypeIsAOpaque(MlirType type) { return unwrap(type).isa(); } +bool mlirTypeIsAOpaque(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirOpaqueTypeGet(MlirContext ctx, MlirStringRef dialectNamespace, MlirStringRef typeData) { @@ -418,9 +436,10 @@ } MlirStringRef mlirOpaqueTypeGetDialectNamespace(MlirType type) { - return wrap(unwrap(type).cast().getDialectNamespace().strref()); + return wrap( + llvm::cast(unwrap(type)).getDialectNamespace().strref()); } MlirStringRef mlirOpaqueTypeGetData(MlirType type) { - return wrap(unwrap(type).cast().getTypeData()); + return wrap(llvm::cast(unwrap(type)).getTypeData()); } diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -172,7 +172,7 @@ } MlirLocation mlirLocationFromAttribute(MlirAttribute attribute) { - return wrap(Location(unwrap(attribute).cast())); + return wrap(Location(llvm::cast(unwrap(attribute)))); } MlirLocation mlirLocationFileLineColGet(MlirContext context, @@ -727,33 +727,33 @@ } bool mlirValueIsABlockArgument(MlirValue value) { - return unwrap(value).isa(); + return llvm::isa(unwrap(value)); } bool mlirValueIsAOpResult(MlirValue value) { - return unwrap(value).isa(); + return llvm::isa(unwrap(value)); } MlirBlock mlirBlockArgumentGetOwner(MlirValue value) { - return wrap(unwrap(value).cast().getOwner()); + return wrap(llvm::cast(unwrap(value)).getOwner()); } intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) { return static_cast( - unwrap(value).cast().getArgNumber()); + llvm::cast(unwrap(value)).getArgNumber()); } void mlirBlockArgumentSetType(MlirValue value, MlirType type) { - unwrap(value).cast().setType(unwrap(type)); + llvm::cast(unwrap(value)).setType(unwrap(type)); } MlirOperation mlirOpResultGetOwner(MlirValue value) { - return wrap(unwrap(value).cast().getOwner()); + return wrap(llvm::cast(unwrap(value)).getOwner()); } intptr_t mlirOpResultGetResultNumber(MlirValue value) { return static_cast( - unwrap(value).cast().getResultNumber()); + llvm::cast(unwrap(value)).getResultNumber()); } MlirType mlirValueGetType(MlirValue value) { @@ -857,7 +857,7 @@ MlirType mlirAttributeGetType(MlirAttribute attribute) { Attribute attr = unwrap(attribute); - if (auto typedAttr = attr.dyn_cast()) + if (auto typedAttr = llvm::dyn_cast(attr)) return wrap(typedAttr.getType()); return wrap(NoneType::get(attr.getContext())); } diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -48,14 +48,15 @@ //===----------------------------------------------------------------------===// template static LogicalResult verifyRawBufferOp(T &op) { - MemRefType bufferType = op.getMemref().getType().template cast(); + MemRefType bufferType = llvm::cast(op.getMemref().getType()); Attribute memorySpace = bufferType.getMemorySpace(); bool isGlobal = false; if (!memorySpace) isGlobal = true; - else if (auto intMemorySpace = memorySpace.dyn_cast()) + else if (auto intMemorySpace = llvm::dyn_cast(memorySpace)) isGlobal = intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1; - else if (auto gpuMemorySpace = memorySpace.dyn_cast()) + else if (auto gpuMemorySpace = + llvm::dyn_cast(memorySpace)) isGlobal = gpuMemorySpace.getValue() == gpu::AddressSpace::Global; if (!isGlobal) @@ -216,11 +217,11 @@ Type sourceElem = sourceType, destElem = destType; uint32_t sourceLen = 1, destLen = 1; - if (auto sourceVector = sourceType.dyn_cast()) { + if (auto sourceVector = llvm::dyn_cast(sourceType)) { sourceLen = sourceVector.getNumElements(); sourceElem = sourceVector.getElementType(); } - if (auto destVector = destType.dyn_cast()) { + if (auto destVector = llvm::dyn_cast(destType)) { destLen = destVector.getNumElements(); destElem = destVector.getElementType(); } @@ -229,7 +230,7 @@ if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ()) { int64_t sourceBLen = 1; Type sourceBElem = sourceBType; - if (auto sourceBVector = sourceBType.dyn_cast()) { + if (auto sourceBVector = llvm::dyn_cast(sourceBType)) { sourceBLen = sourceBVector.getNumElements(); sourceBElem = sourceBVector.getElementType(); } diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -38,7 +38,7 @@ /// top level of a `AffineScope` region is always a valid symbol for all /// uses in that region. bool mlir::affine::isTopLevelValue(Value value, Region *region) { - if (auto arg = value.dyn_cast()) + if (auto arg = llvm::dyn_cast(value)) return arg.getParentRegion() == region; return value.getDefiningOp()->getParentRegion() == region; } @@ -62,7 +62,7 @@ // If it's a top-level value because it's a block operand, i.e. a // function argument, check whether the value replacing it after // inlining is a valid dimension in the new region. - if (value.isa()) + if (llvm::isa(value)) return legalityCheck(mapping.lookup(value), dest); // If it's a top-level value because it's defined in the region, @@ -234,7 +234,7 @@ /// conservatively assume it is not top-level. A value of index type defined at /// the top level is always a valid symbol. bool mlir::affine::isTopLevelValue(Value value) { - if (auto arg = value.dyn_cast()) { + if (auto arg = llvm::dyn_cast(value)) { // The block owning the argument may be unlinked, e.g. when the surrounding // region has not yet been attached to an Op, at which point the parent Op // is null. @@ -273,7 +273,7 @@ // This value has to be a block argument for an op that has the // `AffineScope` trait or for an affine.for or affine.parallel. - auto *parentOp = value.cast().getOwner()->getParentOp(); + auto *parentOp = llvm::cast(value).getOwner()->getParentOp(); return parentOp && (parentOp->hasTrait() || isa(parentOp)); } @@ -296,7 +296,7 @@ if (!op) { // This value has to be a block argument for an affine.for or an // affine.parallel. - auto *parentOp = value.cast().getOwner()->getParentOp(); + auto *parentOp = llvm::cast(value).getOwner()->getParentOp(); return isa(parentOp); } @@ -334,7 +334,7 @@ // Conservatively handle remaining BlockArguments as non-valid symbols. // E.g. scf.for iterArgs. - if (dimOp.getShapedValue().template isa()) + if (llvm::isa(dimOp.getShapedValue())) return false; // The dim op is also okay if its operand memref is a view/subview whose @@ -1221,7 +1221,8 @@ // AffineDialect materializer will create invalid `arith.constant` // operations if the provided Attribute is any other kind of integer. constants.push_back(dialect->materializeConstant( - b, b.getIndexAttr(ofr.get().cast().getInt()), + b, + b.getIndexAttr(llvm::cast(ofr.get()).getInt()), b.getIndexType(), loc)); actualValues.push_back(constants.back()->getResult(0)); } @@ -1785,11 +1786,11 @@ } LogicalResult AffineDmaStartOp::verifyInvariantsImpl() { - if (!getOperand(getSrcMemRefOperandIndex()).getType().isa()) + if (!llvm::isa(getOperand(getSrcMemRefOperandIndex()).getType())) return emitOpError("expected DMA source to be of memref type"); - if (!getOperand(getDstMemRefOperandIndex()).getType().isa()) + if (!llvm::isa(getOperand(getDstMemRefOperandIndex()).getType())) return emitOpError("expected DMA destination to be of memref type"); - if (!getOperand(getTagMemRefOperandIndex()).getType().isa()) + if (!llvm::isa(getOperand(getTagMemRefOperandIndex()).getType())) return emitOpError("expected DMA tag to be of memref type"); unsigned numInputsAllMaps = getSrcMap().getNumInputs() + @@ -1888,7 +1889,7 @@ parser.resolveOperand(numElementsInfo, indexType, result.operands)) return failure(); - if (!type.isa()) + if (!llvm::isa(type)) return parser.emitError(parser.getNameLoc(), "expected tag to be of memref type"); @@ -1899,7 +1900,7 @@ } LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() { - if (!getOperand(0).getType().isa()) + if (!llvm::isa(getOperand(0).getType())) return emitOpError("expected DMA tag to be of memref type"); Region *scope = getAffineScope(*this); for (auto idx : getTagIndices()) { @@ -2073,7 +2074,7 @@ return failure(); // Parse full form - affine map followed by dim and symbol list. - if (auto affineMapAttr = boundAttr.dyn_cast()) { + if (auto affineMapAttr = llvm::dyn_cast(boundAttr)) { unsigned currentNumOperands = result.operands.size(); unsigned numDims; if (parseDimAndSymbolList(p, result.operands, numDims)) @@ -2106,7 +2107,7 @@ } // Parse custom assembly form. - if (auto integerAttr = boundAttr.dyn_cast()) { + if (auto integerAttr = llvm::dyn_cast(boundAttr)) { result.attributes.pop_back(); result.addAttribute( boundAttrStrName, @@ -2296,9 +2297,9 @@ // Compute the max or min as applicable over the results. assert(!foldedResults.empty() && "bounds should have at least one result"); - auto maxOrMin = foldedResults[0].cast().getValue(); + auto maxOrMin = llvm::cast(foldedResults[0]).getValue(); for (unsigned i = 1, e = foldedResults.size(); i < e; i++) { - auto foldedResult = foldedResults[i].cast().getValue(); + auto foldedResult = llvm::cast(foldedResults[i]).getValue(); maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult) : llvm::APIntOps::smin(maxOrMin, foldedResult); } @@ -2653,7 +2654,7 @@ } AffineForOp mlir::affine::getForInductionVarOwner(Value val) { - auto ivArg = val.dyn_cast(); + auto ivArg = llvm::dyn_cast(val); if (!ivArg || !ivArg.getOwner()) return AffineForOp(); auto *containingInst = ivArg.getOwner()->getParent()->getParentOp(); @@ -2664,7 +2665,7 @@ } AffineParallelOp mlir::affine::getAffineParallelInductionVarOwner(Value val) { - auto ivArg = val.dyn_cast(); + auto ivArg = llvm::dyn_cast(val); if (!ivArg || !ivArg.getOwner()) return nullptr; Operation *containingOp = ivArg.getOwner()->getParentOp(); @@ -3113,7 +3114,7 @@ result.addOperands(operands); if (map) result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map)); - auto memrefType = operands[0].getType().cast(); + auto memrefType = llvm::cast(operands[0].getType()); result.types.push_back(memrefType.getElementType()); } @@ -3122,14 +3123,14 @@ assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info"); result.addOperands(memref); result.addOperands(mapOperands); - auto memrefType = memref.getType().cast(); + auto memrefType = llvm::cast(memref.getType()); result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map)); result.types.push_back(memrefType.getElementType()); } void AffineLoadOp::build(OpBuilder &builder, OperationState &result, Value memref, ValueRange indices) { - auto memrefType = memref.getType().cast(); + auto memrefType = llvm::cast(memref.getType()); int64_t rank = memrefType.getRank(); // Create identity map for memrefs with at least one dimension or () -> () // for zero-dimensional memrefs. @@ -3238,11 +3239,11 @@ // Check if the global memref is a constant. auto cstAttr = - global.getConstantInitValue().dyn_cast_or_null(); + llvm::dyn_cast_or_null(global.getConstantInitValue()); if (!cstAttr) return {}; // If it's a splat constant, we can fold irrespective of indices. - if (auto splatAttr = cstAttr.dyn_cast()) + if (auto splatAttr = llvm::dyn_cast(cstAttr)) return splatAttr.getSplatValue(); // Otherwise, we can fold only if we know the indices. if (!getAffineMap().isConstant()) @@ -3271,7 +3272,7 @@ void AffineStoreOp::build(OpBuilder &builder, OperationState &result, Value valueToStore, Value memref, ValueRange indices) { - auto memrefType = memref.getType().cast(); + auto memrefType = llvm::cast(memref.getType()); int64_t rank = memrefType.getRank(); // Create identity map for memrefs with at least one dimension or () -> () // for zero-dimensional memrefs. @@ -4017,7 +4018,7 @@ // Verify reduction ops are all valid for (Attribute attr : getReductions()) { - auto intAttr = attr.dyn_cast(); + auto intAttr = llvm::dyn_cast(attr); if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt())) return emitOpError("invalid reduction attribute"); } @@ -4119,7 +4120,7 @@ p << " reduce ("; llvm::interleaveComma(getReductions(), p, [&](auto &attr) { arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind( - attr.template cast().getInt()); + llvm::cast(attr).getInt()); p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\""; }); p << ") -> (" << getResultTypes() << ")"; @@ -4429,7 +4430,7 @@ void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result, VectorType resultType, Value memref, ValueRange indices) { - auto memrefType = memref.getType().cast(); + auto memrefType = llvm::cast(memref.getType()); int64_t rank = memrefType.getRank(); // Create identity map for memrefs with at least one dimension or () -> () // for zero-dimensional memrefs. @@ -4520,7 +4521,7 @@ void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result, Value valueToStore, Value memref, ValueRange indices) { - auto memrefType = memref.getType().cast(); + auto memrefType = llvm::cast(memref.getType()); int64_t rank = memrefType.getRank(); // Create identity map for memrefs with at least one dimension or () -> () // for zero-dimensional memrefs. diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -36,15 +36,15 @@ static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs) { return builder.getIntegerAttr(res.getType(), - lhs.cast().getInt() + - rhs.cast().getInt()); + llvm::cast(lhs).getInt() + + llvm::cast(rhs).getInt()); } static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs) { return builder.getIntegerAttr(res.getType(), - lhs.cast().getInt() - - rhs.cast().getInt()); + llvm::cast(lhs).getInt() - + llvm::cast(rhs).getInt()); } /// Invert an integer comparison predicate. @@ -92,11 +92,11 @@ } static FailureOr getIntOrSplatIntValue(Attribute attr) { - if (auto intAttr = attr.dyn_cast()) + if (auto intAttr = llvm::dyn_cast(attr)) return intAttr.getValue(); - if (auto splatAttr = attr.dyn_cast()) - if (splatAttr.getElementType().isa()) + if (auto splatAttr = llvm::dyn_cast(attr)) + if (llvm::isa(splatAttr.getElementType())) return splatAttr.getSplatValue(); return failure(); @@ -117,11 +117,11 @@ /// Return the type of the same shape (scalar, vector or tensor) containing i1. static Type getI1SameShape(Type type) { auto i1Type = IntegerType::get(type.getContext(), 1); - if (auto tensorType = type.dyn_cast()) + if (auto tensorType = llvm::dyn_cast(type)) return RankedTensorType::get(tensorType.getShape(), i1Type); - if (type.isa()) + if (llvm::isa(type)) return UnrankedTensorType::get(i1Type); - if (auto vectorType = type.dyn_cast()) + if (auto vectorType = llvm::dyn_cast(type)) return VectorType::get(vectorType.getShape(), i1Type, vectorType.getNumScalableDims()); return i1Type; @@ -134,8 +134,8 @@ void arith::ConstantOp::getAsmResultNames( function_ref setNameFn) { auto type = getType(); - if (auto intCst = getValue().dyn_cast()) { - auto intType = type.dyn_cast(); + if (auto intCst = llvm::dyn_cast(getValue())) { + auto intType = llvm::dyn_cast(type); // Sugar i1 constants with 'true' and 'false'. if (intType && intType.getWidth() == 1) @@ -163,10 +163,11 @@ << " must match return type: " << type; } // Integer values must be signless. - if (type.isa() && !type.cast().isSignless()) + if (llvm::isa(type) && + !llvm::cast(type).isSignless()) return emitOpError("integer return type must be signless"); // Any float or elements attribute are acceptable. - if (!getValue().isa()) { + if (!llvm::isa(getValue())) { return emitOpError( "value must be an integer, float, or elements attribute"); } @@ -175,14 +176,15 @@ bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) { // The value's type must be the same as the provided type. - auto typedAttr = value.dyn_cast(); + auto typedAttr = llvm::dyn_cast(value); if (!typedAttr || typedAttr.getType() != type) return false; // Integer values must be signless. - if (type.isa() && !type.cast().isSignless()) + if (llvm::isa(type) && + !llvm::cast(type).isSignless()) return false; // Integer, float, and element attributes are buildable. - return value.isa(); + return llvm::isa(value); } ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value, @@ -223,7 +225,7 @@ bool arith::ConstantFloatOp::classof(Operation *op) { if (auto constOp = dyn_cast_or_null(op)) - return constOp.getType().isa(); + return llvm::isa(constOp.getType()); return false; } @@ -275,7 +277,7 @@ std::optional> arith::AddUIExtendedOp::getShapeForUnroll() { - if (auto vt = getType(0).dyn_cast()) + if (auto vt = llvm::dyn_cast(getType(0))) return llvm::to_vector<4>(vt.getShape()); return std::nullopt; } @@ -309,7 +311,7 @@ [](APInt a, const APInt &b) { return std::move(a) + b; })) { Attribute overflowAttr = constFoldBinaryOp( ArrayRef({sumAttr, adaptor.getLhs()}), - getI1SameShape(sumAttr.cast().getType()), + getI1SameShape(llvm::cast(sumAttr).getType()), calculateUnsignedOverflow); if (!overflowAttr) return failure(); @@ -385,7 +387,7 @@ std::optional> arith::MulSIExtendedOp::getShapeForUnroll() { - if (auto vt = getType(0).dyn_cast()) + if (auto vt = llvm::dyn_cast(getType(0))) return llvm::to_vector<4>(vt.getShape()); return std::nullopt; } @@ -433,7 +435,7 @@ std::optional> arith::MulUIExtendedOp::getShapeForUnroll() { - if (auto vt = getType(0).dyn_cast()) + if (auto vt = llvm::dyn_cast(getType(0))) return llvm::to_vector<4>(vt.getShape()); return std::nullopt; } @@ -1093,11 +1095,11 @@ template static Type getUnderlyingType(Type type, type_list, type_list) { - if (type.isa() && !type.isa()) + if (llvm::isa(type) && !llvm::isa(type)) return {}; auto underlyingType = getElementTypeOrSelf(type); - if (!underlyingType.isa()) + if (!llvm::isa(underlyingType)) return {}; return underlyingType; @@ -1133,7 +1135,8 @@ Type srcType = getElementTypeOrSelf(op.getIn().getType()); Type dstType = getElementTypeOrSelf(op.getType()); - if (srcType.cast().getWidth() >= dstType.cast().getWidth()) + if (llvm::cast(srcType).getWidth() >= + llvm::cast(dstType).getWidth()) return op.emitError("result type ") << dstType << " must be wider than operand type " << srcType; @@ -1146,7 +1149,8 @@ Type srcType = getElementTypeOrSelf(op.getIn().getType()); Type dstType = getElementTypeOrSelf(op.getType()); - if (srcType.cast().getWidth() <= dstType.cast().getWidth()) + if (llvm::cast(srcType).getWidth() <= + llvm::cast(dstType).getWidth()) return op.emitError("result type ") << dstType << " must be shorter than operand type " << srcType; @@ -1179,7 +1183,7 @@ } Type resType = getElementTypeOrSelf(getType()); - unsigned bitWidth = resType.cast().getWidth(); + unsigned bitWidth = llvm::cast(resType).getWidth(); return constFoldCastOp( adaptor.getOperands(), getType(), [bitWidth](const APInt &a, bool &castStatus) { @@ -1206,7 +1210,7 @@ } Type resType = getElementTypeOrSelf(getType()); - unsigned bitWidth = resType.cast().getWidth(); + unsigned bitWidth = llvm::cast(resType).getWidth(); return constFoldCastOp( adaptor.getOperands(), getType(), [bitWidth](const APInt &a, bool &castStatus) { @@ -1259,8 +1263,8 @@ Type dstType = getElementTypeOrSelf(getType()); // trunci(zexti(a)) -> trunci(a) // trunci(sexti(a)) -> trunci(a) - if (srcType.cast().getWidth() > - dstType.cast().getWidth()) { + if (llvm::cast(srcType).getWidth() > + llvm::cast(dstType).getWidth()) { setOperand(src); return getResult(); } @@ -1276,7 +1280,7 @@ } Type resType = getElementTypeOrSelf(getType()); - unsigned bitWidth = resType.cast().getWidth(); + unsigned bitWidth = llvm::cast(resType).getWidth(); return constFoldCastOp( adaptor.getOperands(), getType(), [bitWidth](const APInt &a, bool &castStatus) { @@ -1307,12 +1311,12 @@ /// can be represented without precision loss or rounding. OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) { auto constOperand = adaptor.getIn(); - if (!constOperand || !constOperand.isa()) + if (!constOperand || !llvm::isa(constOperand)) return {}; // Convert to target type via 'double'. double sourceValue = - constOperand.dyn_cast().getValue().convertToDouble(); + llvm::dyn_cast(constOperand).getValue().convertToDouble(); auto targetAttr = FloatAttr::get(getType(), sourceValue); // Propagate if constant's value does not change after truncation. @@ -1376,7 +1380,7 @@ return constFoldCastOp( adaptor.getOperands(), getType(), [&resEleType](const APInt &a, bool &castStatus) { - FloatType floatTy = resEleType.cast(); + FloatType floatTy = llvm::cast(resEleType); APFloat apf(floatTy.getFloatSemantics(), APInt::getZero(floatTy.getWidth())); apf.convertFromAPInt(a, /*IsSigned=*/false, @@ -1398,7 +1402,7 @@ return constFoldCastOp( adaptor.getOperands(), getType(), [&resEleType](const APInt &a, bool &castStatus) { - FloatType floatTy = resEleType.cast(); + FloatType floatTy = llvm::cast(resEleType); APFloat apf(floatTy.getFloatSemantics(), APInt::getZero(floatTy.getWidth())); apf.convertFromAPInt(a, /*IsSigned=*/true, @@ -1416,7 +1420,7 @@ OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) { Type resType = getElementTypeOrSelf(getType()); - unsigned bitWidth = resType.cast().getWidth(); + unsigned bitWidth = llvm::cast(resType).getWidth(); return constFoldCastOp( adaptor.getOperands(), getType(), [&bitWidth](const APFloat &a, bool &castStatus) { @@ -1438,7 +1442,7 @@ OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) { Type resType = getElementTypeOrSelf(getType()); - unsigned bitWidth = resType.cast().getWidth(); + unsigned bitWidth = llvm::cast(resType).getWidth(); return constFoldCastOp( adaptor.getOperands(), getType(), [&bitWidth](const APFloat &a, bool &castStatus) { @@ -1542,18 +1546,18 @@ return {}; /// Bitcast dense elements. - if (auto denseAttr = operand.dyn_cast_or_null()) - return denseAttr.bitcast(resType.cast().getElementType()); + if (auto denseAttr = llvm::dyn_cast_or_null(operand)) + return denseAttr.bitcast(llvm::cast(resType).getElementType()); /// Other shaped types unhandled. - if (resType.isa()) + if (llvm::isa(resType)) return {}; /// Bitcast integer or float to integer or float. - APInt bits = operand.isa() - ? operand.cast().getValue().bitcastToAPInt() - : operand.cast().getValue(); + APInt bits = llvm::isa(operand) + ? llvm::cast(operand).getValue().bitcastToAPInt() + : llvm::cast(operand).getValue(); - if (auto resFloatType = resType.dyn_cast()) + if (auto resFloatType = llvm::dyn_cast(resType)) return FloatAttr::get(resType, APFloat(resFloatType.getFloatSemantics(), bits)); return IntegerAttr::get(resType, bits); @@ -1618,18 +1622,18 @@ static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) { auto boolAttr = BoolAttr::get(ctx, value); - ShapedType shapedType = type.dyn_cast_or_null(); + ShapedType shapedType = llvm::dyn_cast_or_null(type); if (!shapedType) return boolAttr; return DenseElementsAttr::get(shapedType, boolAttr); } static std::optional getIntegerWidth(Type t) { - if (auto intType = t.dyn_cast()) { + if (auto intType = llvm::dyn_cast(t)) { return intType.getWidth(); } - if (auto vectorIntType = t.dyn_cast()) { - return vectorIntType.getElementType().cast().getWidth(); + if (auto vectorIntType = llvm::dyn_cast(t)) { + return llvm::cast(vectorIntType.getElementType()).getWidth(); } return std::nullopt; } @@ -1817,7 +1821,7 @@ // Get the width of the mantissa. We don't want to hack on conversions that // might lose information from the integer, e.g. "i64 -> float" - FloatType floatTy = op.getRhs().getType().cast(); + FloatType floatTy = llvm::cast(op.getRhs().getType()); int mantissaWidth = floatTy.getFPMantissaWidth(); if (mantissaWidth <= 0) return failure(); @@ -1837,7 +1841,7 @@ // Check to see that the input is converted from an integer type that is // small enough that preserves all bits. - auto intTy = intVal.getType().cast(); + auto intTy = llvm::cast(intVal.getType()); auto intWidth = intTy.getWidth(); // Number of bits representing values, as opposed to the sign @@ -2103,7 +2107,7 @@ LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override { // Cannot extui i1 to i1, or i1 to f32 - if (!op.getType().isa() || op.getType().isInteger(1)) + if (!llvm::isa(op.getType()) || op.getType().isInteger(1)) return failure(); // select %x, c1, %c0 => extui %arg @@ -2230,7 +2234,8 @@ p << " " << getOperands(); p.printOptionalAttrDict((*this)->getAttrs()); p << " : "; - if (ShapedType condType = getCondition().getType().dyn_cast()) + if (ShapedType condType = + llvm::dyn_cast(getCondition().getType())) p << condType << ", "; p << getType(); } @@ -2243,7 +2248,7 @@ // If the result type is a vector or tensor, the type can be a mask with the // same elements. Type resultType = getType(); - if (!resultType.isa()) + if (!llvm::isa(resultType)) return emitOpError() << "expected condition to be a signless i1, but got " << conditionType; Type shapedConditionType = getI1SameShape(resultType); @@ -2320,7 +2325,7 @@ case AtomicRMWKind::maxf: return builder.getFloatAttr( resultType, - APFloat::getInf(resultType.cast().getFloatSemantics(), + APFloat::getInf(llvm::cast(resultType).getFloatSemantics(), /*Negative=*/true)); case AtomicRMWKind::addf: case AtomicRMWKind::addi: @@ -2330,24 +2335,24 @@ case AtomicRMWKind::andi: return builder.getIntegerAttr( resultType, - APInt::getAllOnes(resultType.cast().getWidth())); + APInt::getAllOnes(llvm::cast(resultType).getWidth())); case AtomicRMWKind::maxs: return builder.getIntegerAttr( - resultType, - APInt::getSignedMinValue(resultType.cast().getWidth())); + resultType, APInt::getSignedMinValue( + llvm::cast(resultType).getWidth())); case AtomicRMWKind::minf: return builder.getFloatAttr( resultType, - APFloat::getInf(resultType.cast().getFloatSemantics(), + APFloat::getInf(llvm::cast(resultType).getFloatSemantics(), /*Negative=*/false)); case AtomicRMWKind::mins: return builder.getIntegerAttr( - resultType, - APInt::getSignedMaxValue(resultType.cast().getWidth())); + resultType, APInt::getSignedMaxValue( + llvm::cast(resultType).getWidth())); case AtomicRMWKind::minu: return builder.getIntegerAttr( resultType, - APInt::getMaxValue(resultType.cast().getWidth())); + APInt::getMaxValue(llvm::cast(resultType).getWidth())); case AtomicRMWKind::muli: return builder.getIntegerAttr(resultType, 1); case AtomicRMWKind::mulf: diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp --- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp +++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp @@ -25,7 +25,7 @@ void arith::ConstantOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - auto constAttr = getValue().dyn_cast_or_null(); + auto constAttr = llvm::dyn_cast_or_null(getValue()); if (constAttr) { const APInt &value = constAttr.getValue(); setResultRange(getResult(), ConstantIntRanges::constant(value)); diff --git a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp @@ -37,7 +37,7 @@ auto constantOp = cast(op); assert(value == constantOp.getResult() && "invalid value"); - if (auto attr = constantOp.getValue().dyn_cast()) + if (auto attr = llvm::dyn_cast(constantOp.getValue())) cstr.bound(value) == attr.getInt(); } }; diff --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp --- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp +++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp @@ -28,7 +28,7 @@ /// Return the scalable vector of the same shape and containing i1. static Type getI1SameShape(Type type) { auto i1Type = IntegerType::get(type.getContext(), 1); - if (auto sVectorType = type.dyn_cast()) + if (auto sVectorType = llvm::dyn_cast(type)) return VectorType::get(sVectorType.getShape(), i1Type, sVectorType.getNumScalableDims()); return nullptr; diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -42,7 +42,7 @@ auto executeOp = (*this)->getParentOfType(); auto types = llvm::map_range(executeOp.getBodyResults(), [](const OpResult &result) { - return result.getType().cast().getValueType(); + return llvm::cast(result.getType()).getValueType(); }); if (getOperandTypes() != types) @@ -71,7 +71,7 @@ bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) { const auto getValueOrTokenType = [](Type type) { - if (auto value = type.dyn_cast()) + if (auto value = llvm::dyn_cast(type)) return value.getValueType(); return type; }; @@ -118,7 +118,7 @@ bodyRegion->push_back(new Block); Block &bodyBlock = bodyRegion->front(); for (Value operand : operands) { - auto valueType = operand.getType().dyn_cast(); + auto valueType = llvm::dyn_cast(operand.getType()); bodyBlock.addArgument(valueType ? valueType.getValueType() : operand.getType(), operand.getLoc()); @@ -195,7 +195,7 @@ parser.parseColonType(valueTypes.emplace_back())) return failure(); - auto valueTy = valueTypes.back().dyn_cast(); + auto valueTy = llvm::dyn_cast(valueTypes.back()); unwrappedArgs.back().type = valueTy ? valueTy.getValueType() : Type(); return success(); }; @@ -234,7 +234,7 @@ LogicalResult ExecuteOp::verifyRegions() { // Unwrap async.execute value operands types. auto unwrappedTypes = llvm::map_range(getBodyOperands(), [](Value operand) { - return operand.getType().cast().getValueType(); + return llvm::cast(operand.getType()).getValueType(); }); // Verify that unwrapped argument types matches the body region arguments. @@ -285,7 +285,7 @@ result.attributes.append(attrs.begin(), attrs.end()); // Add unwrapped async.value type to the returned values types. - if (auto valueType = operand.getType().dyn_cast()) + if (auto valueType = llvm::dyn_cast(operand.getType())) result.addTypes(valueType.getValueType()); } @@ -295,7 +295,7 @@ return failure(); // Add unwrapped async.value type to the returned values types. - if (auto valueType = operandType.dyn_cast()) + if (auto valueType = llvm::dyn_cast(operandType)) resultType = valueType.getValueType(); return success(); @@ -310,11 +310,11 @@ Type argType = getOperand().getType(); // Awaiting on a token does not have any results. - if (argType.isa() && !getResultTypes().empty()) + if (llvm::isa(argType) && !getResultTypes().empty()) return emitOpError("awaiting on a token must have empty result"); // Awaiting on a value unwraps the async value type. - if (auto value = argType.dyn_cast()) { + if (auto value = llvm::dyn_cast(argType)) { if (*getResultType() != value.getValueType()) return emitOpError() << "result type " << *getResultType() << " does not match async value type " @@ -375,12 +375,12 @@ for (unsigned i = 0, e = resultTypes.size(); i != e; ++i) { auto type = resultTypes[i]; - if (!type.isa() && !type.isa()) + if (!llvm::isa(type) && !llvm::isa(type)) return emitOpError() << "result type must be async value type or async " "token type, but got " << type; // We only allow AsyncToken appear as the first return value - if (type.isa() && i != 0) { + if (llvm::isa(type) && i != 0) { return emitOpError() << " results' (optional) async token type is expected " "to appear as the 1st return value, but got " @@ -446,7 +446,7 @@ // Get the underlying value types from async types returned from the // parent `async.func` operation. auto types = llvm::map_range(resultTypes, [](const Type &result) { - return result.cast().getValueType(); + return llvm::cast(result).getValueType(); }); if (getOperandTypes() != types) diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -91,9 +91,9 @@ } Operation *bufferization::getOwnerOfValue(Value value) { - if (auto opResult = value.dyn_cast()) + if (auto opResult = llvm::dyn_cast(value)) return opResult.getDefiningOp(); - return value.cast().getOwner()->getParentOp(); + return llvm::cast(value).getOwner()->getParentOp(); } bool bufferization::allocationDoesNotEscape(OpResult opResult) { @@ -109,7 +109,7 @@ return false; auto attr = op->getAttrOfType(BufferizationDialect::kEscapeAttrName); - return !attr[opResult.getResultNumber()].cast().getValue(); + return !llvm::cast(attr[opResult.getResultNumber()]).getValue(); } /// Create an AllocTensorOp for the given shaped value. If `copy` is set, the @@ -119,31 +119,31 @@ OpBuilder &b, Location loc, Value shapedValue, bool escape, const BufferizationOptions &options, bool copy) { Value tensor; - if (shapedValue.getType().isa()) { + if (llvm::isa(shapedValue.getType())) { tensor = shapedValue; - } else if (shapedValue.getType().isa()) { + } else if (llvm::isa(shapedValue.getType())) { tensor = b.create(loc, shapedValue); - } else if (shapedValue.getType().isa() || - shapedValue.getType().isa()) { + } else if (llvm::isa(shapedValue.getType()) || + llvm::isa(shapedValue.getType())) { return getOwnerOfValue(shapedValue) ->emitError("copying of unranked tensors is not implemented"); } else { llvm_unreachable("expected RankedTensorType or MemRefType"); } - RankedTensorType tensorType = tensor.getType().cast(); + RankedTensorType tensorType = llvm::cast(tensor.getType()); SmallVector dynamicSizes; if (!copy) { // Compute the dynamic part of the shape. // First try to query the shape via ReifyRankedShapedTypeOpInterface. bool reifiedShapes = false; - if (shapedValue.getType().isa() && - shapedValue.isa()) { + if (llvm::isa(shapedValue.getType()) && + llvm::isa(shapedValue)) { ReifiedRankedShapedTypeDims resultDims; if (succeeded( reifyResultShapes(b, shapedValue.getDefiningOp(), resultDims))) { reifiedShapes = true; auto &shape = - resultDims[shapedValue.cast().getResultNumber()]; + resultDims[llvm::cast(shapedValue).getResultNumber()]; for (const auto &dim : enumerate(tensorType.getShape())) if (ShapedType::isDynamic(dim.value())) dynamicSizes.push_back(shape[dim.index()].get()); @@ -188,11 +188,11 @@ // Find all out-of-place OpOperands. for (OpOperand &opOperand : op->getOpOperands()) { Type operandType = opOperand.get().getType(); - if (!operandType.isa()) + if (!llvm::isa(operandType)) continue; if (state.isInPlace(opOperand)) continue; - if (operandType.isa()) + if (llvm::isa(operandType)) return op->emitError("copying of unranked tensors is not implemented"); AliasingOpResultList aliasingOpResults = @@ -209,9 +209,8 @@ !state.bufferizesToMemoryWrite(opOperand) && state.getAliasingOpOperands(aliasingOpResults.getAliases()[0].opResult) .getNumAliases() == 1 && - !aliasingOpResults.getAliases()[0] - .opResult.getType() - .isa()) { + !llvm::isa( + aliasingOpResults.getAliases()[0].opResult.getType())) { // The op itself does not write but may create exactly one alias. Instead // of copying the OpOperand, copy the OpResult. The OpResult can sometimes // be smaller than the OpOperand (e.g., in the case of an extract_slice, @@ -281,9 +280,9 @@ AnalysisState analysisState(options); if (op->hasAttr(BufferizationDialect::kEscapeAttrName)) { // AllocTensorOp has one result. - ArrayAttr escapeAttr = - op->getAttr(BufferizationDialect::kEscapeAttrName).cast(); - return !escapeAttr[0].cast().getValue(); + ArrayAttr escapeAttr = llvm::cast( + op->getAttr(BufferizationDialect::kEscapeAttrName)); + return !llvm::cast(escapeAttr[0]).getValue(); } // No "escape" annotation found. @@ -335,8 +334,8 @@ BaseMemRefType defaultUnknownTypeConverter(Value value, Attribute memorySpace, const BufferizationOptions &options) { - return getMemRefTypeWithFullyDynamicLayout(value.getType().cast(), - memorySpace); + return getMemRefTypeWithFullyDynamicLayout( + llvm::cast(value.getType()), memorySpace); } } // namespace @@ -394,7 +393,7 @@ //===----------------------------------------------------------------------===// static void setInsertionPointAfter(OpBuilder &b, Value value) { - if (auto bbArg = value.dyn_cast()) { + if (auto bbArg = llvm::dyn_cast(value)) { b.setInsertionPointToStart(bbArg.getOwner()); } else { b.setInsertionPointAfter(value.getDefiningOp()); @@ -463,7 +462,7 @@ } bool AnalysisState::bufferizesToMemoryWrite(Value value) const { - auto opResult = value.dyn_cast(); + auto opResult = llvm::dyn_cast(value); if (!opResult) return true; auto bufferizableOp = getOptions().dynCastBufferizableOp(value); @@ -476,7 +475,7 @@ /// read. Also takes into account ops that create an alias but do not read by /// themselves (e.g., ExtractSliceOp). bool AnalysisState::isValueRead(Value value) const { - assert(value.getType().isa() && "expected TensorType"); + assert(llvm::isa(value.getType()) && "expected TensorType"); SmallVector workingSet; for (OpOperand &use : value.getUses()) workingSet.push_back(&use); @@ -512,13 +511,13 @@ continue; } - if (value.isa()) { + if (llvm::isa(value)) { if (alwaysIncludeLeaves) result.insert(value); continue; } - OpResult opResult = value.cast(); + OpResult opResult = llvm::cast(value); BufferizableOpInterface bufferizableOp = options.dynCastBufferizableOp(opResult.getDefiningOp()); AliasingOpOperandList aliases = getAliasingOpOperands(opResult); @@ -658,8 +657,8 @@ // bufferization.to_memref is not allowed to change the rank. static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) { #ifndef NDEBUG - auto rankedTensorType = tensor.getType().dyn_cast(); - assert((!rankedTensorType || memrefType.cast().getRank() == + auto rankedTensorType = llvm::dyn_cast(tensor.getType()); + assert((!rankedTensorType || llvm::cast(memrefType).getRank() == rankedTensorType.getRank()) && "to_memref would be invalid: mismatching ranks"); #endif @@ -668,7 +667,7 @@ FailureOr bufferization::getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options) { #ifndef NDEBUG - auto tensorType = value.getType().dyn_cast(); + auto tensorType = llvm::dyn_cast(value.getType()); assert(tensorType && "unexpected non-tensor type"); #endif // NDEBUG @@ -699,7 +698,8 @@ FailureOr bufferization::getBufferType( Value value, const BufferizationOptions &options, const DenseMap &fixedTypes) { - assert(value.getType().isa() && "unexpected non-tensor type"); + assert(llvm::isa(value.getType()) && + "unexpected non-tensor type"); // If the `value` is in `fixedTypes`, return the mapped type. const auto &it = fixedTypes.find(value); @@ -731,11 +731,11 @@ SmallVector replacements; for (OpResult opResult : op->getOpResults()) { Value replacement = values[opResult.getResultNumber()]; - if (opResult.getType().isa()) { + if (llvm::isa(opResult.getType())) { // The OpResult is a tensor. Such values are replaced with memrefs during // bufferization. - assert((replacement.getType().isa() || - replacement.getType().isa()) && + assert((llvm::isa(replacement.getType()) || + llvm::isa(replacement.getType())) && "tensor op result should be replaced with a memref value"); // The existing uses of the OpResult still expect a tensor. Insert a // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually @@ -797,7 +797,7 @@ //===----------------------------------------------------------------------===// bool bufferization::isFunctionArgument(Value value) { - auto bbArg = value.dyn_cast(); + auto bbArg = llvm::dyn_cast(value); if (!bbArg) return false; return isa(bbArg.getOwner()->getParentOp()); @@ -807,17 +807,18 @@ const BufferizationOptions &options, MemRefLayoutAttrInterface layout, Attribute memorySpace) { - auto tensorType = value.getType().cast(); + auto tensorType = llvm::cast(value.getType()); // Case 1: Unranked memref type. - if (auto unrankedTensorType = tensorType.dyn_cast()) { + if (auto unrankedTensorType = + llvm::dyn_cast(tensorType)) { assert(!layout && "UnrankedTensorType cannot have a layout map"); return UnrankedMemRefType::get(unrankedTensorType.getElementType(), memorySpace); } // Case 2: Ranked memref type with specified layout. - auto rankedTensorType = tensorType.cast(); + auto rankedTensorType = llvm::cast(tensorType); if (layout) { return MemRefType::get(rankedTensorType.getShape(), rankedTensorType.getElementType(), layout, @@ -831,13 +832,14 @@ bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, Attribute memorySpace) { // Case 1: Unranked memref type. - if (auto unrankedTensorType = tensorType.dyn_cast()) { + if (auto unrankedTensorType = + llvm::dyn_cast(tensorType)) { return UnrankedMemRefType::get(unrankedTensorType.getElementType(), memorySpace); } // Case 2: Ranked memref type. - auto rankedTensorType = tensorType.cast(); + auto rankedTensorType = llvm::cast(tensorType); int64_t dynamicOffset = ShapedType::kDynamic; SmallVector dynamicStrides(rankedTensorType.getRank(), ShapedType::kDynamic); @@ -854,13 +856,14 @@ bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType, Attribute memorySpace) { // Case 1: Unranked memref type. - if (auto unrankedTensorType = tensorType.dyn_cast()) { + if (auto unrankedTensorType = + llvm::dyn_cast(tensorType)) { return UnrankedMemRefType::get(unrankedTensorType.getElementType(), memorySpace); } // Case 2: Ranked memref type. - auto rankedTensorType = tensorType.cast(); + auto rankedTensorType = llvm::cast(tensorType); MemRefLayoutAttrInterface layout = {}; return MemRefType::get(rankedTensorType.getShape(), rankedTensorType.getElementType(), layout, @@ -943,7 +946,7 @@ Operation *op = opResult.getDefiningOp(); SmallVector result; for (OpOperand &opOperand : op->getOpOperands()) { - if (!opOperand.get().getType().isa()) + if (!llvm::isa(opOperand.get().getType())) continue; AliasingOpResultList aliasingOpResults = state.getAliasingOpResults(opOperand); @@ -957,15 +960,15 @@ FailureOr bufferization::detail::defaultGetBufferType( Value value, const BufferizationOptions &options, const DenseMap &fixedTypes) { - assert(value.getType().isa() && "expected tensor type"); + assert(llvm::isa(value.getType()) && "expected tensor type"); // No further analysis is possible for a block argument. - if (value.isa()) + if (llvm::isa(value)) return bufferization::getMemRefType(value, options); // Value is an OpResult. Operation *op = getOwnerOfValue(value); - auto opResult = value.cast(); + auto opResult = llvm::cast(value); AnalysisState state(options); AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult); if (aliases.getNumAliases() > 0 && @@ -1000,7 +1003,7 @@ // Conservatively assume that everything may be aliasing. AliasingOpOperandList r; for (OpOperand &operand : opResult.getDefiningOp()->getOpOperands()) - if (operand.get().getType().isa()) + if (llvm::isa(operand.get().getType())) r.addAlias({&operand, BufferRelation::Unknown, /*isDefinite=*/false}); return r; } @@ -1010,7 +1013,7 @@ // Conservatively assume that everything may be aliasing. AliasingOpResultList r; for (OpResult result : opOperand.getOwner()->getOpResults()) - if (result.getType().isa()) + if (llvm::isa(result.getType())) r.addAlias({result, BufferRelation::Unknown, /*isDefinite=*/false}); return r; } diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp @@ -62,7 +62,7 @@ Operation *op, unsigned /*regionIndex*/, unsigned argIndex, NamedAttribute attr) { if (attr.getName() == kWritableAttrName) { - if (!attr.getValue().isa()) { + if (!llvm::isa(attr.getValue())) { return op->emitError() << "'" << kWritableAttrName << "' is expected to be a boolean attribute"; } @@ -75,11 +75,11 @@ return success(); } if (attr.getName() == kBufferAccessAttrName) { - if (!attr.getValue().isa()) { + if (!llvm::isa(attr.getValue())) { return op->emitError() << "'" << kBufferAccessAttrName << "' is expected to be a string attribute"; } - StringRef str = attr.getValue().cast().getValue(); + StringRef str = llvm::cast(attr.getValue()).getValue(); if (str != "none" && str != "read" && str != "write" && str != "read-write") return op->emitError() << "invalid value for '" << kBufferAccessAttrName << "'"; @@ -89,7 +89,7 @@ return success(); } if (attr.getName() == kBufferLayoutAttrName) { - if (!attr.getValue().isa()) { + if (!llvm::isa(attr.getValue())) { return op->emitError() << "'" << kBufferLayoutAttrName << "' is expected to be a affine map attribute"; } @@ -109,7 +109,7 @@ using bufferization::BufferizableOpInterface; if (attr.getName() == kEscapeAttrName) { - auto arrayAttr = attr.getValue().dyn_cast(); + auto arrayAttr = llvm::dyn_cast(attr.getValue()); if (!arrayAttr) return op->emitError() << "'" << kEscapeAttrName << "' is expected to be a bool array attribute"; @@ -124,13 +124,13 @@ << "'" << kEscapeAttrName << "' only valid on bufferizable ops"; for (const auto &it : llvm::enumerate(arrayAttr)) { auto attr = it.value(); - auto boolAttr = attr.dyn_cast(); + auto boolAttr = llvm::dyn_cast(attr); if (!boolAttr) return op->emitError() << "'" << kEscapeAttrName << "' is expected to be a bool array attribute"; if (!boolAttr.getValue()) continue; - if (!op->getResult(it.index()).getType().isa()) + if (!llvm::isa(op->getResult(it.index()).getType())) return op->emitError() << "'" << kEscapeAttrName << "' only valid for tensor results"; if (!bufferizableOp.bufferizesToAllocation(op->getOpResult(it.index()))) diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -27,7 +27,7 @@ FailureOr mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value, MemRefType destType) { - auto srcType = value.getType().cast(); + auto srcType = llvm::cast(value.getType()); // Element type, rank and memory space must match. if (srcType.getElementType() != destType.getElementType()) @@ -100,9 +100,9 @@ return success(); } - auto rankedSrcType = srcType.dyn_cast(); - auto rankedDestType = destType.dyn_cast(); - auto unrankedSrcType = srcType.dyn_cast(); + auto rankedSrcType = llvm::dyn_cast(srcType); + auto rankedDestType = llvm::dyn_cast(destType); + auto unrankedSrcType = llvm::dyn_cast(srcType); // Ranked memref -> Ranked memref cast. if (rankedSrcType && rankedDestType) { @@ -132,13 +132,13 @@ void mlir::bufferization::populateDynamicDimSizes( OpBuilder &b, Location loc, Value shapedValue, SmallVector &dynamicDims) { - auto shapedType = shapedValue.getType().cast(); + auto shapedType = llvm::cast(shapedValue.getType()); for (int64_t i = 0; i < shapedType.getRank(); ++i) { if (shapedType.isDynamicDim(i)) { - if (shapedType.isa()) { + if (llvm::isa(shapedType)) { dynamicDims.push_back(b.create(loc, shapedValue, i)); } else { - assert(shapedType.isa() && "expected tensor"); + assert(llvm::isa(shapedType) && "expected tensor"); dynamicDims.push_back(b.create(loc, shapedValue, i)); } } @@ -191,7 +191,7 @@ // Should the buffer be deallocated? bool dealloc = - shouldDeallocateOpResult(getResult().cast(), options); + shouldDeallocateOpResult(llvm::cast(getResult()), options); // Replace op. replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc); @@ -431,7 +431,7 @@ AllocTensorOp::getOperandSegmentSizeAttr()}); p << " : "; auto type = getResult().getType(); - if (auto validType = type.dyn_cast<::mlir::TensorType>()) + if (auto validType = llvm::dyn_cast<::mlir::TensorType>(type)) p.printStrippedAttrOrType(validType); else p << type; @@ -620,8 +620,8 @@ toMemref.getOperand().getDefiningOp(); if (!tensorCastOperand) return failure(); - auto srcTensorType = - tensorCastOperand.getOperand().getType().dyn_cast(); + auto srcTensorType = llvm::dyn_cast( + tensorCastOperand.getOperand().getType()); if (!srcTensorType) return failure(); auto memrefType = MemRefType::get(srcTensorType.getShape(), diff --git a/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp b/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp --- a/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp +++ b/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp @@ -34,7 +34,7 @@ Location loc) { if (complex::ConstantOp::isBuildableWith(value, type)) { return builder.create(loc, type, - value.cast()); + llvm::cast(value)); } return arith::ConstantOp::materialize(builder, value, type, loc); } @@ -46,16 +46,16 @@ ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::llvm::APFloat real, ::llvm::APFloat imag, ::mlir::Type type) { - if (!type.isa()) + if (!llvm::isa(type)) return emitError() << "complex attribute must be a complex type."; - Type elementType = type.cast().getElementType(); - if (!elementType.isa()) + Type elementType = llvm::cast(type).getElementType(); + if (!llvm::isa(elementType)) return emitError() << "element type of the complex attribute must be float like type."; const auto &typeFloatSemantics = - elementType.cast().getFloatSemantics(); + llvm::cast(elementType).getFloatSemantics(); if (&real.getSemantics() != &typeFloatSemantics) return emitError() << "type doesn't match the type implied by its `real` value"; @@ -67,7 +67,7 @@ } void complex::NumberAttr::print(AsmPrinter &printer) const { - printer << "<:" << getType().cast().getElementType() << " " + printer << "<:" << llvm::cast(getType()).getElementType() << " " << getReal() << ", " << getImag() << ">"; } diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp --- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp +++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp @@ -27,18 +27,18 @@ } bool ConstantOp::isBuildableWith(Attribute value, Type type) { - if (auto arrAttr = value.dyn_cast()) { - auto complexTy = type.dyn_cast(); + if (auto arrAttr = llvm::dyn_cast(value)) { + auto complexTy = llvm::dyn_cast(type); if (!complexTy || arrAttr.size() != 2) return false; auto complexEltTy = complexTy.getElementType(); - if (auto fre = arrAttr[0].dyn_cast()) { - auto im = arrAttr[1].dyn_cast(); + if (auto fre = llvm::dyn_cast(arrAttr[0])) { + auto im = llvm::dyn_cast(arrAttr[1]); return im && fre.getType() == complexEltTy && im.getType() == complexEltTy; } - if (auto ire = arrAttr[0].dyn_cast()) { - auto im = arrAttr[1].dyn_cast(); + if (auto ire = llvm::dyn_cast(arrAttr[0])) { + auto im = llvm::dyn_cast(arrAttr[1]); return im && ire.getType() == complexEltTy && im.getType() == complexEltTy; } @@ -55,8 +55,8 @@ } auto complexEltTy = getType().getElementType(); - auto re = arrayAttr[0].dyn_cast(); - auto im = arrayAttr[1].dyn_cast(); + auto re = llvm::dyn_cast(arrayAttr[0]); + auto im = llvm::dyn_cast(arrayAttr[1]); if (!re || !im) return emitOpError("requires attribute's elements to be float attributes"); if (complexEltTy != re.getType() || complexEltTy != im.getType()) { @@ -129,8 +129,8 @@ // complex.add(a, complex.constant<0.0, 0.0>) -> a if (auto constantOp = getRhs().getDefiningOp()) { auto arrayAttr = constantOp.getValue(); - if (arrayAttr[0].cast().getValue().isZero() && - arrayAttr[1].cast().getValue().isZero()) { + if (llvm::cast(arrayAttr[0]).getValue().isZero() && + llvm::cast(arrayAttr[1]).getValue().isZero()) { return getLhs(); } } @@ -151,8 +151,8 @@ // complex.sub(a, complex.constant<0.0, 0.0>) -> a if (auto constantOp = getRhs().getDefiningOp()) { auto arrayAttr = constantOp.getValue(); - if (arrayAttr[0].cast().getValue().isZero() && - arrayAttr[1].cast().getValue().isZero()) { + if (llvm::cast(arrayAttr[0]).getValue().isZero() && + llvm::cast(arrayAttr[1]).getValue().isZero()) { return getLhs(); } } diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp --- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp @@ -125,7 +125,7 @@ // Otherwise, we need to remap any argument operands. for (Value operand : operands) { - BlockArgument argOperand = operand.dyn_cast(); + BlockArgument argOperand = llvm::dyn_cast(operand); if (argOperand && argOperand.getOwner() == successor) argStorage.push_back(successorOperands[argOperand.getArgNumber()]); else @@ -442,7 +442,8 @@ } Block *CondBranchOp::getSuccessorForOperands(ArrayRef operands) { - if (IntegerAttr condAttr = operands.front().dyn_cast_or_null()) + if (IntegerAttr condAttr = + llvm::dyn_cast_or_null(operands.front())) return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest(); return nullptr; } @@ -601,7 +602,7 @@ return getDefaultDestination(); SuccessorRange caseDests = getCaseDestinations(); - if (auto value = operands.front().dyn_cast_or_null()) { + if (auto value = llvm::dyn_cast_or_null(operands.front())) { for (const auto &it : llvm::enumerate(caseValues->getValues())) if (it.value() == value.getValue()) return caseDests[it.index()]; diff --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp --- a/mlir/lib/Dialect/DLTI/DLTI.cpp +++ b/mlir/lib/Dialect/DLTI/DLTI.cpp @@ -215,7 +215,7 @@ typeSample.getContext()->getLoadedDialect() && "unexpected data layout entry for built-in type"); - auto interface = typeSample.cast(); + auto interface = llvm::cast(typeSample); if (!interface.areCompatible(entriesForType.lookup(kvp.first), kvp.second)) return failure(); @@ -250,7 +250,7 @@ // Only combine with attributes of the same kind. // TODO: reconsider this when the need arises. if (llvm::any_of(specs, [](DataLayoutSpecInterface spec) { - return !spec.isa(); + return !llvm::isa(spec); })) return {}; @@ -334,7 +334,7 @@ Location loc) const final { StringRef entryName = entry.getKey().get().strref(); if (entryName == DLTIDialect::kDataLayoutEndiannessKey) { - auto value = entry.getValue().dyn_cast(); + auto value = llvm::dyn_cast(entry.getValue()); if (value && (value.getValue() == DLTIDialect::kDataLayoutEndiannessBig || value.getValue() == DLTIDialect::kDataLayoutEndiannessLittle)) @@ -383,7 +383,7 @@ LogicalResult DLTIDialect::verifyOperationAttribute(Operation *op, NamedAttribute attr) { if (attr.getName() == DLTIDialect::kDataLayoutAttrName) { - if (!attr.getValue().isa()) { + if (!llvm::isa(attr.getValue())) { return op->emitError() << "'" << DLTIDialect::kDataLayoutAttrName << "' is expected to be a #dlti.dl_spec attribute"; } diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -73,10 +73,10 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { Type input = inputs.front(), output = outputs.front(); - return ((input.isa()) && - (output.isa())); + return ((llvm::isa(input)) && + (llvm::isa(output))); } //===----------------------------------------------------------------------===// @@ -90,8 +90,8 @@ if (std::optional argsAttr = getArgs()) { for (Attribute arg : *argsAttr) { - auto intAttr = arg.dyn_cast(); - if (intAttr && intAttr.getType().isa()) { + auto intAttr = llvm::dyn_cast(arg); + if (intAttr && llvm::isa(intAttr.getType())) { int64_t index = intAttr.getInt(); // Args with elements of type index must be in range // [0..operands.size). @@ -99,7 +99,8 @@ return emitOpError("index argument is out of range"); // Args with elements of type ArrayAttr must have a type. - } else if (arg.isa() /*&& arg.getType().isa()*/) { + } else if (llvm::isa( + arg) /*&& arg.getType().isa()*/) { // FIXME: Array attributes never have types return emitOpError("array argument has no type"); } @@ -108,7 +109,7 @@ if (std::optional templateArgsAttr = getTemplateArgs()) { for (Attribute tArg : *templateArgsAttr) { - if (!tArg.isa()) + if (!llvm::isa(tArg)) return emitOpError("template argument has invalid type"); } } @@ -122,17 +123,17 @@ /// The constant op requires that the attribute's type matches the return type. LogicalResult emitc::ConstantOp::verify() { - if (getValueAttr().isa()) + if (llvm::isa(getValueAttr())) return success(); // Value must not be empty - StringAttr strAttr = getValueAttr().dyn_cast(); + StringAttr strAttr = llvm::dyn_cast(getValueAttr()); if (strAttr && strAttr.getValue().empty()) return emitOpError() << "value must not be empty"; auto value = cast(getValueAttr()); Type type = getType(); - if (!value.getType().isa() && type != value.getType()) + if (!llvm::isa(value.getType()) && type != value.getType()) return emitOpError() << "requires attribute's type (" << value.getType() << ") to match op's return type (" << type << ")"; return success(); @@ -183,12 +184,12 @@ /// The variable op requires that the attribute's type matches the return type. LogicalResult emitc::VariableOp::verify() { - if (getValueAttr().isa()) + if (llvm::isa(getValueAttr())) return success(); auto value = cast(getValueAttr()); Type type = getType(); - if (!value.getType().isa() && type != value.getType()) + if (!llvm::isa(value.getType()) && type != value.getType()) return emitOpError() << "requires attribute's type (" << value.getType() << ") to match op's return type (" << type << ")"; return success(); diff --git a/mlir/lib/Dialect/Func/IR/FuncOps.cpp b/mlir/lib/Dialect/Func/IR/FuncOps.cpp --- a/mlir/lib/Dialect/Func/IR/FuncOps.cpp +++ b/mlir/lib/Dialect/Func/IR/FuncOps.cpp @@ -112,7 +112,7 @@ Type type, Location loc) { if (ConstantOp::isBuildableWith(value, type)) return builder.create(loc, type, - value.cast()); + llvm::cast(value)); return nullptr; } @@ -209,7 +209,7 @@ } bool ConstantOp::isBuildableWith(Attribute value, Type type) { - return value.isa() && type.isa(); + return llvm::isa(value) && llvm::isa(type); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -220,7 +220,7 @@ LogicalResult GPUDialect::verifyOperationAttribute(Operation *op, NamedAttribute attr) { - if (!attr.getValue().isa() || + if (!llvm::isa(attr.getValue()) || attr.getName() != getContainerModuleAttrName()) return success(); @@ -368,14 +368,14 @@ ArrayRef attributions, gpu::AddressSpace memorySpace) { for (Value v : attributions) { - auto type = v.getType().dyn_cast(); + auto type = llvm::dyn_cast(v.getType()); if (!type) return op->emitOpError() << "expected memref type in attribution"; // We can only verify the address space if it hasn't already been lowered // from the AddressSpaceAttr to a target-specific numeric value. auto addressSpace = - type.getMemorySpace().dyn_cast_or_null(); + llvm::dyn_cast_or_null(type.getMemorySpace()); if (!addressSpace) continue; if (addressSpace.getValue() != memorySpace) @@ -395,7 +395,7 @@ return (opName != gpu::AllReduceOperation::AND && opName != gpu::AllReduceOperation::OR && opName != gpu::AllReduceOperation::XOR) || - resType.isa(); + llvm::isa(resType); } LogicalResult gpu::AllReduceOp::verifyRegions() { @@ -1186,7 +1186,7 @@ size_t attributionIndex = pair.index(); DictionaryAttr attrs; if (attributes && attributionIndex < attributes.size()) - attrs = attributes[attributionIndex].cast(); + attrs = llvm::cast(attributes[attributionIndex]); if (attrs) p.printOptionalAttrDict(attrs.getValue()); }); @@ -1221,10 +1221,10 @@ static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index, StringAttr attrName) { - auto allAttrs = op->getAttr(attrName).dyn_cast_or_null(); + auto allAttrs = llvm::dyn_cast_or_null(op->getAttr(attrName)); if (!allAttrs || index >= allAttrs.size()) return DictionaryAttr(); - return allAttrs[index].cast(); + return llvm::cast(allAttrs[index]); } DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) { @@ -1238,7 +1238,7 @@ static void setAttributionAttrs(GPUFuncOp op, unsigned index, DictionaryAttr value, StringAttr attrName) { MLIRContext *ctx = op.getContext(); - auto allAttrs = op->getAttr(attrName).dyn_cast_or_null(); + auto allAttrs = llvm::dyn_cast_or_null(op->getAttr(attrName)); SmallVector elements; if (allAttrs) elements.append(allAttrs.begin(), allAttrs.end()); @@ -1379,7 +1379,7 @@ auto maybeAttr = op->getAttr(attrName); if (!maybeAttr) return success(); - auto array = maybeAttr.dyn_cast(); + auto array = llvm::dyn_cast(maybeAttr); if (!array) return op.emitOpError(attrName + " must be a dense i32 array"); if (array.size() != 3) @@ -1536,9 +1536,9 @@ LogicalResult SubgroupMmaLoadMatrixOp::verify() { auto srcType = getSrcMemref().getType(); auto resType = getRes().getType(); - auto resMatrixType = resType.cast(); + auto resMatrixType = llvm::cast(resType); auto operand = resMatrixType.getOperand(); - auto srcMemrefType = srcType.cast(); + auto srcMemrefType = llvm::cast(srcType); if (!isLastMemrefDimUnitStride(srcMemrefType)) return emitError( @@ -1558,8 +1558,8 @@ LogicalResult SubgroupMmaStoreMatrixOp::verify() { auto srcType = getSrc().getType(); auto dstType = getDstMemref().getType(); - auto srcMatrixType = srcType.cast(); - auto dstMemrefType = dstType.cast(); + auto srcMatrixType = llvm::cast(srcType); + auto dstMemrefType = llvm::cast(dstType); if (!isLastMemrefDimUnitStride(dstMemrefType)) return emitError( @@ -1579,9 +1579,9 @@ LogicalResult SubgroupMmaComputeOp::verify() { enum OperandMap { A, B, C }; SmallVector opTypes; - opTypes.push_back(getOpA().getType().cast()); - opTypes.push_back(getOpB().getType().cast()); - opTypes.push_back(getOpC().getType().cast()); + opTypes.push_back(llvm::cast(getOpA().getType())); + opTypes.push_back(llvm::cast(getOpB().getType())); + opTypes.push_back(llvm::cast(getOpC().getType())); if (!opTypes[A].getOperand().equals("AOp") || !opTypes[B].getOperand().equals("BOp") || @@ -1688,7 +1688,7 @@ //===----------------------------------------------------------------------===// LogicalResult AllocOp::verify() { - auto memRefType = getMemref().getType().cast(); + auto memRefType = llvm::cast(getMemref().getType()); if (static_cast(getDynamicSizes().size()) != memRefType.getNumDynamicDims()) @@ -1719,7 +1719,7 @@ if (!index) return failure(); - auto memrefType = dimOp.getSource().getType().dyn_cast(); + auto memrefType = llvm::dyn_cast(dimOp.getSource().getType()); if (!memrefType || !memrefType.isDynamicDim(index.value())) return failure(); diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp --- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp +++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp @@ -39,7 +39,8 @@ // Materialize integer attributes as `index`. if (auto indexValue = dyn_cast(value)) { - if (!indexValue.getType().isa() || !type.isa()) + if (!llvm::isa(indexValue.getType()) || + !llvm::isa(type)) return nullptr; assert(indexValue.getValue().getBitWidth() == IndexType::kInternalStorageBitWidth); @@ -399,7 +400,8 @@ //===----------------------------------------------------------------------===// bool CastSOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) { - return lhsTypes.front().isa() != rhsTypes.front().isa(); + return llvm::isa(lhsTypes.front()) != + llvm::isa(rhsTypes.front()); } //===----------------------------------------------------------------------===// @@ -407,7 +409,8 @@ //===----------------------------------------------------------------------===// bool CastUOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) { - return lhsTypes.front().isa() != rhsTypes.front().isa(); + return llvm::isa(lhsTypes.front()) != + llvm::isa(rhsTypes.front()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -179,7 +179,7 @@ //===----------------------------------------------------------------------===// void AllocaOp::print(OpAsmPrinter &p) { - Type elemTy = getType().cast().getElementType(); + Type elemTy = llvm::cast(getType()).getElementType(); if (!elemTy) elemTy = *getElemType(); @@ -220,7 +220,7 @@ std::optional alignmentAttr = result.attributes.getNamed("alignment"); if (alignmentAttr.has_value()) { - auto alignmentInt = alignmentAttr->getValue().dyn_cast(); + auto alignmentInt = llvm::dyn_cast(alignmentAttr->getValue()); if (!alignmentInt) return parser.emitError(parser.getNameLoc(), "expected integer alignment"); @@ -229,7 +229,7 @@ } // Extract the result type from the trailing function type. - auto funcType = type.dyn_cast(); + auto funcType = llvm::dyn_cast(type); if (!funcType || funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) return parser.emitError( @@ -240,7 +240,7 @@ return failure(); Type resultType = funcType.getResult(0); - if (auto ptrResultType = resultType.dyn_cast()) { + if (auto ptrResultType = llvm::dyn_cast(resultType)) { if (ptrResultType.isOpaque()) result.addAttribute(kElemTypeAttrName, TypeAttr::get(elemType)); } @@ -266,7 +266,7 @@ } LogicalResult AllocaOp::verify() { - return verifyOpaquePtr(getOperation(), getType().cast(), + return verifyOpaquePtr(getOperation(), llvm::cast(getType()), getElemType()); } @@ -410,7 +410,7 @@ size_t index = 0; llvm::interleave( - llvm::zip(caseValues.cast(), caseDestinations), + llvm::zip(llvm::cast(caseValues), caseDestinations), [&](auto i) { p << " "; p << std::get<0>(i).getLimitedValue(); @@ -457,11 +457,11 @@ /// Returns the elemental type of any LLVM-compatible vector type or self. static Type extractVectorElementType(Type type) { - if (auto vectorType = type.dyn_cast()) + if (auto vectorType = llvm::dyn_cast(type)) return vectorType.getElementType(); - if (auto scalableVectorType = type.dyn_cast()) + if (auto scalableVectorType = llvm::dyn_cast(type)) return scalableVectorType.getElementType(); - if (auto fixedVectorType = type.dyn_cast()) + if (auto fixedVectorType = llvm::dyn_cast(type)) return fixedVectorType.getElementType(); return type; } @@ -470,7 +470,7 @@ Value basePtr, ArrayRef indices, bool inbounds, ArrayRef attributes) { auto ptrType = - extractVectorElementType(basePtr.getType()).cast(); + llvm::cast(extractVectorElementType(basePtr.getType())); assert(!ptrType.isOpaque() && "expected non-opaque pointer, provide elementType explicitly when " "opaque pointers are used"); @@ -543,8 +543,7 @@ result.addAttribute(getInboundsAttrName(result.name), builder.getUnitAttr()); } - if (extractVectorElementType(basePtr.getType()) - .cast() + if (llvm::cast(extractVectorElementType(basePtr.getType())) .isOpaque()) result.addAttribute(kElemTypeAttrName, TypeAttr::get(elementType)); result.addOperands(basePtr); @@ -695,7 +694,7 @@ LogicalResult LLVM::GEPOp::verify() { if (failed(verifyOpaquePtr( getOperation(), - extractVectorElementType(getType()).cast(), + llvm::cast(extractVectorElementType(getType())), getElemType()))) return failure(); @@ -716,8 +715,8 @@ if (std::optional elemType = getElemType()) return *elemType; - return extractVectorElementType(getBase().getType()) - .cast() + return llvm::cast( + extractVectorElementType(getBase().getType())) .getElementType(); } @@ -729,16 +728,16 @@ /// integer and float types with limited bit width are supported. Additionally, /// depending on the operation pointers may be supported as well. static bool isTypeCompatibleWithAtomicOp(Type type, bool isPointerTypeAllowed) { - if (type.isa()) + if (llvm::isa(type)) return isPointerTypeAllowed; std::optional bitWidth; - if (auto floatType = type.dyn_cast()) { + if (auto floatType = llvm::dyn_cast(type)) { if (!isCompatibleFloatingPointType(type)) return false; bitWidth = floatType.getWidth(); } - if (auto integerType = type.dyn_cast()) + if (auto integerType = llvm::dyn_cast(type)) bitWidth = integerType.getWidth(); // The type is neither an integer, float, or pointer type. if (!bitWidth) @@ -777,7 +776,7 @@ void LoadOp::build(OpBuilder &builder, OperationState &state, Value addr, unsigned alignment, bool isVolatile, bool isNonTemporal) { - auto type = addr.getType().cast().getElementType(); + auto type = llvm::cast(addr.getType()).getElementType(); assert(type && "must provide explicit element type to the constructor " "when the pointer type is opaque"); build(builder, state, type, addr, alignment, isVolatile, isNonTemporal); @@ -801,7 +800,7 @@ // std::nullopt if the given type is not the pointer type. static std::optional getLoadStoreElementType(OpAsmParser &parser, Type type, SMLoc trailingTypeLoc) { - auto llvmTy = type.dyn_cast(); + auto llvmTy = llvm::dyn_cast(type); if (!llvmTy) { parser.emitError(trailingTypeLoc, "expected LLVM pointer type"); return std::nullopt; @@ -919,7 +918,7 @@ ValueRange args) { SmallVector results; Type resultType = func.getFunctionType().getReturnType(); - if (!resultType.isa()) + if (!llvm::isa(resultType)) results.push_back(resultType); build(builder, state, results, SymbolRefAttr::get(func), args, nullptr, nullptr); @@ -964,7 +963,7 @@ if (!getNumOperands()) return emitOpError( "must have either a `callee` attribute or at least an operand"); - auto ptrType = getOperand(0).getType().dyn_cast(); + auto ptrType = llvm::dyn_cast(getOperand(0).getType()); if (!ptrType) return emitOpError("indirect call expects a pointer as callee: ") << getOperand(0).getType(); @@ -988,7 +987,7 @@ fnType = fn.getFunctionType(); } - LLVMFunctionType funcType = fnType.dyn_cast(); + LLVMFunctionType funcType = llvm::dyn_cast(fnType); if (!funcType) return emitOpError("callee does not have a functional type: ") << fnType; @@ -1023,11 +1022,11 @@ << " != " << funcType.getParamType(i); if (getNumResults() == 0 && - !funcType.getReturnType().isa()) + !llvm::isa(funcType.getReturnType())) return emitOpError() << "expected function call to produce a value"; if (getNumResults() != 0 && - funcType.getReturnType().isa()) + llvm::isa(funcType.getReturnType())) return emitOpError() << "calling function with void result must not produce values"; @@ -1083,7 +1082,7 @@ return parser.emitError(trailingTypesLoc, "expected indirect call to have 2 trailing types"); - auto funcType = types.pop_back_val().dyn_cast(); + auto funcType = llvm::dyn_cast(types.pop_back_val()); if (!funcType) return parser.emitError(trailingTypesLoc, "expected trailing function type"); @@ -1091,7 +1090,7 @@ return parser.emitError(trailingTypesLoc, "expected function with 0 or 1 result"); if (funcType.getNumResults() == 1 && - funcType.getResult(0).isa()) + llvm::isa(funcType.getResult(0))) return parser.emitError(trailingTypesLoc, "expected a non-void result type"); @@ -1292,7 +1291,7 @@ for (unsigned idx = 0, ie = getNumOperands(); idx < ie; idx++) { value = getOperand(idx); - bool isFilter = value.getType().isa(); + bool isFilter = llvm::isa(value.getType()); if (isFilter) { // FIXME: Verify filter clauses when arrays are appropriately handled } else { @@ -1324,7 +1323,7 @@ for (auto value : getOperands()) { // Similar to llvm - if clause is an array type then it is filter // clause else catch clause - bool isArrayTy = value.getType().isa(); + bool isArrayTy = llvm::isa(value.getType()); p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : " << value.getType() << ") "; } @@ -1383,13 +1382,13 @@ // structures. Check the position index before accessing, it is supposed to // be in bounds. for (int64_t idx : position) { - if (auto arrayType = llvmType.dyn_cast()) { + if (auto arrayType = llvm::dyn_cast(llvmType)) { if (idx < 0 || static_cast(idx) >= arrayType.getNumElements()) { emitError("position out of bounds: ") << idx; return {}; } llvmType = arrayType.getElementType(); - } else if (auto structType = llvmType.dyn_cast()) { + } else if (auto structType = llvm::dyn_cast(llvmType)) { if (idx < 0 || static_cast(idx) >= structType.getBody().size()) { emitError("position out of bounds: ") << idx; @@ -1409,10 +1408,10 @@ static Type getInsertExtractValueElementType(Type llvmType, ArrayRef position) { for (int64_t idx : position) { - if (auto structType = llvmType.dyn_cast()) + if (auto structType = llvm::dyn_cast(llvmType)) llvmType = structType.getBody()[idx]; else - llvmType = llvmType.cast().getElementType(); + llvmType = llvm::cast(llvmType).getElementType(); } return llvmType; } @@ -1519,7 +1518,7 @@ return success(); Type expectedType = parent.getFunctionType().getReturnType(); - if (expectedType.isa()) { + if (llvm::isa(expectedType)) { if (!getArg()) return success(); InFlightDiagnostic diag = emitOpError("expected no operands"); @@ -1527,7 +1526,7 @@ return diag; } if (!getArg()) { - if (expectedType.isa()) + if (llvm::isa(expectedType)) return success(); InFlightDiagnostic diag = emitOpError("expected 1 operand"); diag.attachNote(parent->getLoc()) << "when returning from function"; @@ -1664,7 +1663,7 @@ getVisibility_AttrName()}); // Print the trailing type unless it's a string global. - if (getValueOrNull().dyn_cast_or_null()) + if (llvm::dyn_cast_or_null(getValueOrNull())) return; p << " : " << getType(); @@ -1779,7 +1778,7 @@ Region &initRegion = *result.addRegion(); if (types.empty()) { - if (auto strAttr = value.dyn_cast_or_null()) { + if (auto strAttr = llvm::dyn_cast_or_null(value)) { MLIRContext *context = parser.getContext(); auto arrayType = LLVM::LLVMArrayType::get(IntegerType::get(context, 8), strAttr.getValue().size()); @@ -1802,15 +1801,15 @@ } static bool isZeroAttribute(Attribute value) { - if (auto intValue = value.dyn_cast()) + if (auto intValue = llvm::dyn_cast(value)) return intValue.getValue().isZero(); - if (auto fpValue = value.dyn_cast()) + if (auto fpValue = llvm::dyn_cast(value)) return fpValue.getValue().isZero(); - if (auto splatValue = value.dyn_cast()) + if (auto splatValue = llvm::dyn_cast(value)) return isZeroAttribute(splatValue.getSplatValue()); - if (auto elementsValue = value.dyn_cast()) + if (auto elementsValue = llvm::dyn_cast(value)) return llvm::all_of(elementsValue.getValues(), isZeroAttribute); - if (auto arrayValue = value.dyn_cast()) + if (auto arrayValue = llvm::dyn_cast(value)) return llvm::all_of(arrayValue.getValue(), isZeroAttribute); return false; } @@ -1822,10 +1821,10 @@ if ((*this)->getParentOp() && !satisfiesLLVMModule((*this)->getParentOp())) return emitOpError("must appear at the module level"); - if (auto strAttr = getValueOrNull().dyn_cast_or_null()) { - auto type = getType().dyn_cast(); + if (auto strAttr = llvm::dyn_cast_or_null(getValueOrNull())) { + auto type = llvm::dyn_cast(getType()); IntegerType elementType = - type ? type.getElementType().dyn_cast() : nullptr; + type ? llvm::dyn_cast(type.getElementType()) : nullptr; if (!elementType || elementType.getWidth() != 8 || type.getNumElements() != strAttr.getValue().size()) return emitOpError( @@ -1844,7 +1843,7 @@ } if (getLinkage() == Linkage::Appending) { - if (!getType().isa()) { + if (!llvm::isa(getType())) { return emitOpError() << "expected array type for '" << stringifyLinkage(Linkage::Appending) << "' linkage"; @@ -1892,7 +1891,7 @@ LogicalResult GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) { for (Attribute ctor : getCtors()) { - if (failed(verifySymbolAttrUse(ctor.cast(), *this, + if (failed(verifySymbolAttrUse(llvm::cast(ctor), *this, symbolTable))) return failure(); } @@ -1913,7 +1912,7 @@ LogicalResult GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) { for (Attribute dtor : getDtors()) { - if (failed(verifySymbolAttrUse(dtor.cast(), *this, + if (failed(verifySymbolAttrUse(llvm::cast(dtor), *this, symbolTable))) return failure(); } @@ -2012,7 +2011,7 @@ if (argAttrs.empty()) return; - assert(type.cast().getNumParams() == argAttrs.size() && + assert(llvm::cast(type).getNumParams() == argAttrs.size() && "expected as many argument attribute lists as arguments"); function_interface_impl::addArgAndResultAttrs( builder, result, argAttrs, /*resultAttrs=*/std::nullopt, @@ -2143,7 +2142,7 @@ argTypes.push_back(fnType.getParamType(i)); Type returnType = fnType.getReturnType(); - if (!returnType.isa()) + if (!llvm::isa(returnType)) resTypes.push_back(returnType); function_interface_impl::printFunctionSignature(p, *this, argTypes, @@ -2251,8 +2250,8 @@ //===----------------------------------------------------------------------===// LogicalResult LLVM::ConstantOp::verify() { - if (StringAttr sAttr = getValue().dyn_cast()) { - auto arrayType = getType().dyn_cast(); + if (StringAttr sAttr = llvm::dyn_cast(getValue())) { + auto arrayType = llvm::dyn_cast(getType()); if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() || !arrayType.getElementType().isInteger(8)) { return emitOpError() << "expected array type of " @@ -2261,35 +2260,35 @@ } return success(); } - if (auto structType = getType().dyn_cast()) { + if (auto structType = llvm::dyn_cast(getType())) { if (structType.getBody().size() != 2 || structType.getBody()[0] != structType.getBody()[1]) { return emitError() << "expected struct type with two elements of the " "same type, the type of a complex constant"; } - auto arrayAttr = getValue().dyn_cast(); + auto arrayAttr = llvm::dyn_cast(getValue()); if (!arrayAttr || arrayAttr.size() != 2) { return emitOpError() << "expected array attribute with two elements, " "representing a complex constant"; } - auto re = arrayAttr[0].dyn_cast(); - auto im = arrayAttr[1].dyn_cast(); + auto re = llvm::dyn_cast(arrayAttr[0]); + auto im = llvm::dyn_cast(arrayAttr[1]); if (!re || !im || re.getType() != im.getType()) { return emitOpError() << "expected array attribute with two elements of the same type"; } Type elementType = structType.getBody()[0]; - if (!elementType - .isa()) { + if (!llvm::isa( + elementType)) { return emitError() << "expected struct element types to be floating point type or " "integer type"; } return success(); } - if (!getValue().isa()) + if (!llvm::isa(getValue())) return emitOpError() << "only supports integer, float, string or elements attributes"; return success(); @@ -2314,7 +2313,7 @@ } LogicalResult AtomicRMWOp::verify() { - auto ptrType = getPtr().getType().cast(); + auto ptrType = llvm::cast(getPtr().getType()); auto valType = getVal().getType(); if (!ptrType.isOpaque() && valType != ptrType.getElementType()) return emitOpError("expected LLVM IR element type for operand #0 to " @@ -2327,7 +2326,7 @@ if (!isTypeCompatibleWithAtomicOp(valType, /*isPointerTypeAllowed=*/false)) return emitOpError("unexpected LLVM IR type for 'xchg' bin_op"); } else { - auto intType = valType.dyn_cast(); + auto intType = llvm::dyn_cast(valType); unsigned intBitWidth = intType ? intType.getWidth() : 0; if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && intBitWidth != 64) @@ -2367,7 +2366,7 @@ } LogicalResult AtomicCmpXchgOp::verify() { - auto ptrType = getPtr().getType().cast(); + auto ptrType = llvm::cast(getPtr().getType()); if (!ptrType) return emitOpError("expected LLVM IR pointer type for operand #0"); auto valType = getVal().getType(); @@ -2421,10 +2420,10 @@ } LogicalResult LLVM::BitcastOp::verify() { - auto resultType = extractVectorElementType(getResult().getType()) - .dyn_cast(); - auto sourceType = - extractVectorElementType(getArg().getType()).dyn_cast(); + auto resultType = llvm::dyn_cast( + extractVectorElementType(getResult().getType())); + auto sourceType = llvm::dyn_cast( + extractVectorElementType(getArg().getType())); // If one of the types is a pointer (or vector of pointers), then // both source and result type have to be pointers. @@ -2435,7 +2434,8 @@ return success(); auto isVector = [](Type type) { - return type.isa(); + return llvm::isa( + type); }; // Due to bitcast requiring both operands to be of the same size, it is not @@ -2480,7 +2480,7 @@ // gep %x:T, 0 -> %x if (getBase().getType() == getType() && indices.size() == 1) - if (auto integer = indices[0].dyn_cast_or_null()) + if (auto integer = llvm::dyn_cast_or_null(indices[0])) if (integer.getValue().isZero()) return getBase(); @@ -2488,7 +2488,7 @@ bool changed = false; SmallVector gepArgs; for (auto iter : llvm::enumerate(indices)) { - auto integer = iter.value().dyn_cast_or_null(); + auto integer = llvm::dyn_cast_or_null(iter.value()); // Constant indices can only be int32_t, so if integer does not fit we // are forced to keep it dynamic, despite being a constant. if (!indices.isDynamicIndex(iter.index()) || !integer || @@ -2686,7 +2686,7 @@ SmallVectorImpl &operands = tbaaGraph[tdOp.getSymNameAttr()]->operands; for (Attribute attr : tdOp.getMembers()) { - StringAttr symbolRef = attr.cast().getAttr(); + StringAttr symbolRef = llvm::cast(attr).getAttr(); if (failed(verifyReference(op, symbolRef, tdOp.getMembersAttrName()))) return failure(); @@ -2888,7 +2888,7 @@ // llvm::DataLayout constructor. if (attr.getName() != LLVM::LLVMDialect::getDataLayoutAttrName()) return success(); - if (auto stringAttr = attr.getValue().dyn_cast()) + if (auto stringAttr = llvm::dyn_cast(attr.getValue())) return verifyDataLayoutString( stringAttr.getValue(), [op](const Twine &message) { op->emitOpError() << message.str(); }); @@ -2909,28 +2909,28 @@ StringAttr name = paramAttr.getName(); auto checkUnitAttrType = [&]() -> LogicalResult { - if (!paramAttr.getValue().isa()) + if (!llvm::isa(paramAttr.getValue())) return op->emitError() << name << " should be a unit attribute"; return success(); }; auto checkTypeAttrType = [&]() -> LogicalResult { - if (!paramAttr.getValue().isa()) + if (!llvm::isa(paramAttr.getValue())) return op->emitError() << name << " should be a type attribute"; return success(); }; auto checkIntegerAttrType = [&]() -> LogicalResult { - if (!paramAttr.getValue().isa()) + if (!llvm::isa(paramAttr.getValue())) return op->emitError() << name << " should be an integer attribute"; return success(); }; auto checkPointerType = [&]() -> LogicalResult { - if (!paramType.isa()) + if (!llvm::isa(paramType)) return op->emitError() << name << " attribute attached to non-pointer LLVM type"; return success(); }; auto checkIntegerType = [&]() -> LogicalResult { - if (!paramType.isa()) + if (!llvm::isa(paramType)) return op->emitError() << name << " attribute attached to non-integer LLVM type"; return success(); @@ -2938,8 +2938,8 @@ auto checkPointerTypeMatches = [&]() -> LogicalResult { if (failed(checkPointerType())) return failure(); - auto ptrType = paramType.cast(); - auto typeAttr = paramAttr.getValue().cast(); + auto ptrType = llvm::cast(paramType); + auto typeAttr = llvm::cast(paramAttr.getValue()); if (!ptrType.isOpaque() && ptrType.getElementType() != typeAttr.getValue()) return op->emitError() @@ -3033,7 +3033,7 @@ // Check to see if this function has a void return with a result attribute // to it. It isn't clear what semantics we would assign to that. - if (resType.isa()) + if (llvm::isa(resType)) return op->emitError() << "cannot attach result attributes to functions " "with a void return"; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp @@ -176,7 +176,7 @@ if (auto func = dyn_cast(parentOp)) { // Use the alignment attribute set for this argument in the parent function // if it has been set. - auto blockArg = value.cast(); + auto blockArg = llvm::cast(value); if (Attribute alignAttr = func.getArgAttr( blockArg.getArgNumber(), LLVM::LLVMDialect::getAlignAttrName())) return cast(alignAttr).getValue().getLimitedValue(); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp @@ -29,7 +29,7 @@ // names processed here (e.g. 'tbaa'). This verification // is redundant in some cases. if (!llvm::all_of(symbolRefs, [](Attribute attr) { - return attr && attr.isa(); + return attr && llvm::isa(attr); })) return op->emitOpError() << name << " attribute failed to satisfy constraint: " diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp @@ -24,7 +24,8 @@ /// internal functions to avoid getting a verbose `!llvm` prefix. Otherwise /// prints it as usual. static void dispatchPrint(AsmPrinter &printer, Type type) { - if (isCompatibleType(type) && !type.isa()) + if (isCompatibleType(type) && + !llvm::isa(type)) return mlir::LLVM::detail::printType(type, printer); printer.printType(type); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -130,8 +130,9 @@ //===----------------------------------------------------------------------===// bool LLVMArrayType::isValidElementType(Type type) { - return !type.isa(); + return !llvm::isa( + type); } LLVMArrayType LLVMArrayType::get(Type elementType, unsigned numElements) { @@ -186,11 +187,11 @@ //===----------------------------------------------------------------------===// bool LLVMFunctionType::isValidArgumentType(Type type) { - return !type.isa(); + return !llvm::isa(type); } bool LLVMFunctionType::isValidResultType(Type type) { - return !type.isa(); + return !llvm::isa(type); } LLVMFunctionType LLVMFunctionType::get(Type result, ArrayRef arguments, @@ -239,9 +240,9 @@ if (!type) return true; return isCompatibleOuterType(type) - ? !type.isa() - : type.isa(); + ? !llvm::isa(type) + : llvm::isa(type); } LLVMPointerType LLVMPointerType::get(Type pointee, unsigned addressSpace) { @@ -266,7 +267,7 @@ std::optional mlir::LLVM::extractPointerSpecValue(Attribute attr, PtrDLEntryPos pos) { - auto spec = attr.cast(); + auto spec = llvm::cast(attr); auto idx = static_cast(pos); if (idx >= spec.size()) return std::nullopt; @@ -285,8 +286,8 @@ for (DataLayoutEntryInterface entry : params) { if (!entry.isTypeEntry()) continue; - if (entry.getKey().get().cast().getAddressSpace() == - type.getAddressSpace()) { + if (llvm::cast(entry.getKey().get()) + .getAddressSpace() == type.getAddressSpace()) { currentEntry = entry.getValue(); break; } @@ -350,11 +351,11 @@ continue; unsigned size = kDefaultPointerSizeBits; unsigned abi = kDefaultPointerAlignment; - auto newType = newEntry.getKey().get().cast(); + auto newType = llvm::cast(newEntry.getKey().get()); const auto *it = llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) { if (auto type = entry.getKey().dyn_cast()) { - return type.cast().getAddressSpace() == + return llvm::cast(type).getAddressSpace() == newType.getAddressSpace(); } return false; @@ -362,7 +363,7 @@ if (it == oldLayout.end()) { llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) { if (auto type = entry.getKey().dyn_cast()) { - return type.cast().getAddressSpace() == 0; + return llvm::cast(type).getAddressSpace() == 0; } return false; }); @@ -372,7 +373,7 @@ abi = *extractPointerSpecValue(*it, PtrDLEntryPos::Abi); } - Attribute newSpec = newEntry.getValue().cast(); + Attribute newSpec = llvm::cast(newEntry.getValue()); unsigned newSize = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Size); unsigned newAbi = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Abi); if (size != newSize || abi < newAbi || abi % newAbi != 0) @@ -386,8 +387,8 @@ for (DataLayoutEntryInterface entry : entries) { if (!entry.isTypeEntry()) continue; - auto key = entry.getKey().get().cast(); - auto values = entry.getValue().dyn_cast(); + auto key = llvm::cast(entry.getKey().get()); + auto values = llvm::dyn_cast(entry.getValue()); if (!values || (values.size() != 3 && values.size() != 4)) { return emitError(loc) << "expected layout attribute for " << entry.getKey().get() @@ -412,8 +413,9 @@ //===----------------------------------------------------------------------===// bool LLVMStructType::isValidElementType(Type type) { - return !type.isa(); + return !llvm::isa( + type); } LLVMStructType LLVMStructType::getIdentified(MLIRContext *context, @@ -538,7 +540,7 @@ if (currentEntry == params.end()) return std::nullopt; - auto attr = currentEntry->getValue().cast(); + auto attr = llvm::cast(currentEntry->getValue()); if (pos == StructDLEntryPos::Preferred && attr.size() <= static_cast(StructDLEntryPos::Preferred)) // If no preferred was specified, fall back to abi alignment @@ -586,7 +588,7 @@ } static unsigned extractStructSpecValue(Attribute attr, StructDLEntryPos pos) { - return attr.cast() + return llvm::cast(attr) .getValues()[static_cast(pos)]; } @@ -619,8 +621,8 @@ if (!entry.isTypeEntry()) continue; - auto key = entry.getKey().get().cast(); - auto values = entry.getValue().dyn_cast(); + auto key = llvm::cast(entry.getKey().get()); + auto values = llvm::dyn_cast(entry.getValue()); if (!values || (values.size() != 2 && values.size() != 1)) { return emitError(loc) << "expected layout attribute for " << entry.getKey().get() @@ -676,7 +678,7 @@ } bool LLVMFixedVectorType::isValidElementType(Type type) { - return type.isa(); + return llvm::isa(type); } LogicalResult @@ -705,10 +707,11 @@ } bool LLVMScalableVectorType::isValidElementType(Type type) { - if (auto intType = type.dyn_cast()) + if (auto intType = llvm::dyn_cast(type)) return intType.isSignless(); - return isCompatibleFloatingPointType(type) || type.isa(); + return isCompatibleFloatingPointType(type) || + llvm::isa(type); } LogicalResult @@ -724,7 +727,7 @@ bool mlir::LLVM::isCompatibleOuterType(Type type) { // clang-format off - if (type.isa< + if (llvm::isa< BFloat16Type, Float16Type, Float32Type, @@ -745,17 +748,17 @@ LLVMScalableVectorType, LLVMVoidType, LLVMX86MMXType - >()) { + >(type)) { // clang-format on return true; } // Only signless integers are compatible. - if (auto intType = type.dyn_cast()) + if (auto intType = llvm::dyn_cast(type)) return intType.isSignless(); // 1D vector types are compatible. - if (auto vecType = type.dyn_cast()) + if (auto vecType = llvm::dyn_cast(type)) return vecType.getRank() == 1; return false; @@ -835,22 +838,22 @@ } bool mlir::LLVM::isCompatibleFloatingPointType(Type type) { - return type.isa(); + return llvm::isa(type); } bool mlir::LLVM::isCompatibleVectorType(Type type) { - if (type.isa()) + if (llvm::isa(type)) return true; - if (auto vecType = type.dyn_cast()) { + if (auto vecType = llvm::dyn_cast(type)) { if (vecType.getRank() != 1) return false; Type elementType = vecType.getElementType(); - if (auto intType = elementType.dyn_cast()) + if (auto intType = llvm::dyn_cast(elementType)) return intType.isSignless(); - return elementType.isa(); + return llvm::isa(elementType); } return false; } @@ -883,13 +886,12 @@ } bool mlir::LLVM::isScalableVectorType(Type vectorType) { - assert( - (vectorType - .isa()) && - "expected LLVM-compatible vector type"); - return !vectorType.isa() && - (vectorType.isa() || - vectorType.cast().isScalable()); + assert((llvm::isa( + vectorType)) && + "expected LLVM-compatible vector type"); + return !llvm::isa(vectorType) && + (llvm::isa(vectorType) || + llvm::cast(vectorType).isScalable()); } Type mlir::LLVM::getVectorType(Type elementType, unsigned numElements, @@ -970,9 +972,9 @@ elementSize.isScalable()); }) .Default([](Type ty) { - assert((ty.isa()) && + assert((llvm::isa(ty)) && "unexpected missing support for primitive type"); return llvm::TypeSize::Fixed(0); }); diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -90,13 +90,13 @@ return NVVM::MMATypes::f32; if (operandElType.isF32() && !isAccumulator) return NVVM::MMATypes::tf32; - if (operandElType.isa()) { + if (llvm::isa(operandElType)) { if (isAccumulator) return NVVM::MMATypes::s32; return std::nullopt; } - if (auto structType = operandElType.dyn_cast()) { + if (auto structType = llvm::dyn_cast(operandElType)) { if (structType.getBody().empty()) return std::nullopt; return inferOperandMMAType(structType.getBody()[0], isAccumulator); @@ -526,9 +526,9 @@ LogicalResult ShflOp::verify() { if (!(*this)->getAttrOfType("return_value_and_is_valid")) return success(); - auto type = getType().dyn_cast(); + auto type = llvm::dyn_cast(getType()); auto elementType = (type && type.getBody().size() == 2) - ? type.getBody()[1].dyn_cast() + ? llvm::dyn_cast(type.getBody()[1]) : nullptr; if (!elementType || elementType.getWidth() != 1) return emitError("expected return type to be a two-element struct with " @@ -600,7 +600,7 @@ LogicalResult NVVM::WMMALoadOp::verify() { unsigned addressSpace = - getPtr().getType().cast().getAddressSpace(); + llvm::cast(getPtr().getType()).getAddressSpace(); if (addressSpace != 0 && addressSpace != 1 && addressSpace != 3) return emitOpError("expected source pointer in memory " "space 0, 1, 3"); @@ -620,7 +620,7 @@ LogicalResult NVVM::WMMAStoreOp::verify() { unsigned addressSpace = - getPtr().getType().cast().getAddressSpace(); + llvm::cast(getPtr().getType()).getAddressSpace(); if (addressSpace != 0 && addressSpace != 1 && addressSpace != 3) return emitOpError("expected operands to be a source pointer in memory " "space 0, 1, 3"); @@ -672,7 +672,7 @@ LogicalResult NVVM::LdMatrixOp::verify() { unsigned addressSpace = - getPtr().getType().cast().getAddressSpace(); + llvm::cast(getPtr().getType()).getAddressSpace(); if (addressSpace != 3) return emitOpError("expected source pointer in memory space 3"); @@ -725,13 +725,13 @@ // If maxntid and reqntid exist, it must be an array with max 3 dim if (attrName == NVVMDialect::getMaxntidAttrName() || attrName == NVVMDialect::getReqntidAttrName()) { - auto values = attr.getValue().dyn_cast(); + auto values = llvm::dyn_cast(attr.getValue()); if (!values || values.empty() || values.size() > 3) return op->emitError() << "'" << attrName << "' attribute must be integer array with maximum 3 index"; - for (auto val : attr.getValue().cast()) { - if (!val.dyn_cast()) + for (auto val : llvm::cast(attr.getValue())) { + if (!llvm::dyn_cast(val)) return op->emitError() << "'" << attrName << "' attribute must be integer array with maximum 3 index"; @@ -740,7 +740,7 @@ // If minctasm and maxnreg exist, it must be an array with max 3 dim if (attrName == NVVMDialect::getMinctasmAttrName() || attrName == NVVMDialect::getMaxnregAttrName()) { - if (!attr.getValue().dyn_cast()) + if (!llvm::dyn_cast(attr.getValue())) return op->emitError() << "'" << attrName << "' attribute must be integer constant"; } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -506,15 +506,15 @@ /// the type of `source`. static Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim) { - if (source.getType().isa()) + if (llvm::isa(source.getType())) return b.createOrFold(loc, source, dim); - if (source.getType().isa()) + if (llvm::isa(source.getType())) return b.createOrFold(loc, source, dim); llvm_unreachable("Expected MemRefType or TensorType"); } static OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value source, int64_t dim) { - auto shapedType = source.getType().cast(); + auto shapedType = llvm::cast(source.getType()); if (!shapedType.hasRank() || shapedType.isDynamicDim(dim)) return createOrFoldDimOp(b, loc, source, dim); return b.getIndexAttr(shapedType.getDimSize(dim)); @@ -644,7 +644,7 @@ for (OpOperand *opOperand : getDpsInitOperands()) { SmallVector shapes; for (int64_t dim : llvm::seq(0, getRank(opOperand))) { - auto shapedType = opOperand->get().getType().cast(); + auto shapedType = llvm::cast(opOperand->get().getType()); if (!shapedType.isDynamicDim(dim)) { // Static dim: Return IntegerAttr. shapes.push_back(b.getIndexAttr(shapedType.getDimSize(dim))); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -63,7 +63,8 @@ TypeRange inputTypes, TypeRange outputTypes, ArrayRef attrs, RegionBuilderFn regionBuilder) { - assert(llvm::all_of(outputTypes, [](Type t) { return t.isa(); })); + assert(llvm::all_of(outputTypes, + [](Type t) { return llvm::isa(t); })); // TODO: atm all operands go through getElementTypeOrSelf, // reconsider when we have evidence we need to. @@ -106,7 +107,7 @@ resultTensorTypes.value_or(TypeRange()); if (!resultTensorTypes) copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes), - [](Type type) { return type.isa(); }); + [](Type type) { return llvm::isa(type); }); state.addOperands(inputs); state.addOperands(outputs); @@ -173,7 +174,7 @@ // Otherwise we append it to the discardable attributes dictionary where it is // handled by the generic Operation::create(...) method. if (result.propertiesAttr) { - NamedAttrList attrs = result.propertiesAttr.cast(); + NamedAttrList attrs = llvm::cast(result.propertiesAttr); attrs.append("operand_segment_sizes", parser.getBuilder().getDenseI32ArrayAttr( {static_cast(inputsOperands.size()), @@ -448,9 +449,15 @@ return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast); } - bool isComplex(Value value) { return value.getType().isa(); } - bool isFloatingPoint(Value value) { return value.getType().isa(); } - bool isInteger(Value value) { return value.getType().isa(); } + bool isComplex(Value value) { + return llvm::isa(value.getType()); + } + bool isFloatingPoint(Value value) { + return llvm::isa(value.getType()); + } + bool isInteger(Value value) { + return llvm::isa(value.getType()); + } OpBuilder getBuilder() { OpBuilder builder(context); @@ -748,8 +755,7 @@ for (auto attr : (*this)->getAttrs()) { if (attr.getName() == getIteratorTypesAttrName()) { auto iteratorTypes = - attr.getValue() - .cast() + llvm::cast(attr.getValue()) .getAsValueRange(); // Convert IteratorType enums into the string representation. This is // needed, because tests still use the old format when 'iterator_types' @@ -873,13 +879,13 @@ ValueRange results, const OpOperandVector &inputOperands, const OpOperandVector &outputOperands) { for (auto *operand : inputOperands) { - if (!operand->get().getType().isa()) + if (!llvm::isa(operand->get().getType())) continue; effects.emplace_back(MemoryEffects::Read::get(), operand->get(), SideEffects::DefaultResource::get()); } for (auto *operand : outputOperands) { - if (!operand->get().getType().isa()) + if (!llvm::isa(operand->get().getType())) continue; effects.emplace_back(MemoryEffects::Read::get(), operand->get(), SideEffects::DefaultResource::get()); @@ -942,7 +948,7 @@ // number to use for replacing uses of this operation. SmallVector returnedArgs; for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) { - auto yieldArg = yieldVal.value().dyn_cast(); + auto yieldArg = llvm::dyn_cast(yieldVal.value()); if (!yieldArg || yieldArg.getOwner() != &body) return failure(); unsigned argumentNumber = yieldArg.getArgNumber(); @@ -1003,7 +1009,7 @@ // Add result types. for (Type outputType : outputTypes) { - if (outputType.isa()) + if (llvm::isa(outputType)) result.addTypes(outputType); } @@ -1037,7 +1043,7 @@ // Add output types for `RankedTensorType` output arguments. Type initType = init.getType(); - if (initType.isa()) + if (llvm::isa(initType)) result.addTypes(initType); if (bodyBuild) @@ -1056,8 +1062,9 @@ b.setInsertionPointToStart(&block); SmallVector bbArgs; for (auto &operand : operands) { - block.addArgument(operand.getType().cast().getElementType(), - b.getUnknownLoc()); + block.addArgument( + llvm::cast(operand.getType()).getElementType(), + b.getUnknownLoc()); } SmallVector payloadOpOperands; // If initFirst flag is enabled, we consider init as the first position of @@ -1074,8 +1081,8 @@ Operation *payloadOp = b.create( result.location, b.getStringAttr(payloadOpName.getStringRef()), payloadOpOperands, - TypeRange{ - result.operands.back().getType().cast().getElementType()}, + TypeRange{llvm::cast(result.operands.back().getType()) + .getElementType()}, payloadOpAttrs); b.create(result.location, payloadOp->getResults()); } @@ -1151,7 +1158,8 @@ std::string attrToElide; p << " { " << payloadOp->getName().getStringRef(); for (const auto &attr : payloadOp->getAttrs()) { - auto fastAttr = attr.getValue().dyn_cast(); + auto fastAttr = + llvm::dyn_cast(attr.getValue()); if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) { attrToElide = attr.getName().str(); elidedAttrs.push_back(attrToElide); @@ -1200,7 +1208,8 @@ // The parameters of mapper should all match the element type of inputs. for (const auto &[bbArgType, inputArg] : llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) { - auto inputElemType = inputArg.getType().cast().getElementType(); + auto inputElemType = + llvm::cast(inputArg.getType()).getElementType(); if (bbArgType != inputElemType) { return emitOpError() << "expected element type of input " << inputElemType << " to match bbArg type " << bbArgType; @@ -1210,7 +1219,7 @@ // The shape of each input must match the shape of the output. auto outputShape = getInit().getType().getShape(); for (Type inputArgType : TypeRange{getInputs()}) { - auto inputElemShape = inputArgType.cast().getShape(); + auto inputElemShape = llvm::cast(inputArgType).getShape(); if (inputElemShape != outputShape) { return emitOpError() << "expected shape of input (" << inputElemShape << ") to match shape of output (" << outputShape @@ -1270,7 +1279,7 @@ // Add output types for `RankedTensorType` output arguments. for (Value init : inits) { Type initType = init.getType(); - if (initType.isa()) + if (llvm::isa(initType)) result.addTypes(initType); } @@ -1280,7 +1289,8 @@ } SmallVector ReduceOp::getIteratorTypesArray() { - int64_t inputRank = getInputs()[0].getType().cast().getRank(); + int64_t inputRank = + llvm::cast(getInputs()[0].getType()).getRank(); SmallVector iteratorTypes(inputRank, utils::IteratorType::parallel); for (int64_t reductionDim : getDimensions()) @@ -1289,7 +1299,8 @@ } ArrayAttr ReduceOp::getIndexingMaps() { - int64_t inputRank = getInputs()[0].getType().cast().getRank(); + int64_t inputRank = + llvm::cast(getInputs()[0].getType()).getRank(); SmallVector affineMaps( getNumDpsInputs(), AffineMap::getMultiDimIdentityMap(inputRank, getContext())); @@ -1390,8 +1401,8 @@ ArrayRef dimensionsRef = getDimensions(); for (int64_t i = 1; i < getNumDpsInputs(); ++i) { - if (getInputs()[i].getType().cast().getShape() != - getInputs()[0].getType().cast().getShape()) { + if (llvm::cast(getInputs()[i].getType()).getShape() != + llvm::cast(getInputs()[0].getType()).getShape()) { return emitOpError() << "expects all inputs to have the same shapes. " "Shape at input-index " << i @@ -1399,16 +1410,16 @@ } } for (int64_t i = 1; i < getNumDpsInits(); ++i) { - if (getInits()[i].getType().cast().getShape() != - getInits()[0].getType().cast().getShape()) { + if (llvm::cast(getInits()[i].getType()).getShape() != + llvm::cast(getInits()[0].getType()).getShape()) { return emitOpError() << "expects all outputs to have the same shapes. " "Shape at output-index " << i << " is not equal to the shape at output-index 0."; } } - auto inputType = getInputs()[0].getType().cast(); - auto initType = getInits()[0].getType().cast(); + auto inputType = llvm::cast(getInputs()[0].getType()); + auto initType = llvm::cast(getInits()[0].getType()); DenseSet dimensionsToReduce; for (int64_t dimension : dimensionsRef) { @@ -1449,7 +1460,8 @@ // Check that the first block arguments match the element type of the inputs. for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) { - Type inputElementType = input.getType().cast().getElementType(); + Type inputElementType = + llvm::cast(input.getType()).getElementType(); if (inputElementType != bbArg.getType()) return emitOpError() << "input element type " << inputElementType @@ -1462,7 +1474,7 @@ llvm::zip(getDpsInitOperands(), block->getArguments().take_back(getNumDpsInits()))) { auto outputElementType = - output->get().getType().cast().getElementType(); + llvm::cast(output->get().getType()).getElementType(); if (outputElementType != bbArg.getType()) return emitOpError() << "output element type " << outputElementType @@ -1496,7 +1508,7 @@ // Add output types for `RankedTensorType` output arguments. Type initType = init.getType(); - if (initType.isa()) + if (llvm::isa(initType)) result.addTypes(initType); buildIdentityRegion(builder, result.location, *result.addRegion(), input, @@ -1610,7 +1622,7 @@ // Add output types for `RankedTensorType` output arguments. Type initType = init.getType(); - if (initType.isa()) + if (llvm::isa(initType)) result.addTypes(initType); buildIdentityRegion(builder, result.location, *result.addRegion(), input, @@ -1828,7 +1840,7 @@ } static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) { - if (auto memref = t.dyn_cast()) { + if (auto memref = llvm::dyn_cast(t)) { ss << "view"; for (auto size : memref.getShape()) if (size < 0) @@ -1838,14 +1850,14 @@ if (failed(appendMangledType(ss, memref.getElementType()))) return failure(); if (auto as = memref.getMemorySpace()) { - if (auto attr = as.dyn_cast()) + if (auto attr = llvm::dyn_cast(as)) ss << "as" << attr.getInt(); else return failure(); } return success(); } - if (auto vec = t.dyn_cast()) { + if (auto vec = llvm::dyn_cast(t)) { ss << "vector"; llvm::interleave( vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; }); @@ -1864,9 +1876,9 @@ std::string name(op->getName().getStringRef().str()); std::string fun = ""; for (NamedAttribute kv : op->getAttrs()) { - if (UnaryFnAttr ufa = kv.getValue().dyn_cast()) { + if (UnaryFnAttr ufa = llvm::dyn_cast(kv.getValue())) { fun = stringifyEnum(ufa.getValue()).str() + "_"; - } else if (BinaryFnAttr bfa = kv.getValue().dyn_cast()) { + } else if (BinaryFnAttr bfa = llvm::dyn_cast(kv.getValue())) { fun = stringifyEnum(bfa.getValue()).str() + "_"; } } @@ -1898,7 +1910,7 @@ // Linalg "inputs" may be either tensor or memref type. // tensor<0xelt_type> is a convention that may not always mean // "0 iterations". Only erase in cases we see memref<...x0x...>. - auto mt = opOperand.get().getType().dyn_cast(); + auto mt = llvm::dyn_cast(opOperand.get().getType()); if (!mt) continue; if (llvm::is_contained(op.getShape(&opOperand), 0)) { @@ -1934,9 +1946,10 @@ rewriter.setInsertionPoint(linalgOp); Location loc = linalgOp.getLoc(); - OpResult resultValue = castOp.getSource().cast(); + OpResult resultValue = llvm::cast(castOp.getSource()); unsigned resultNumber = resultValue.getResultNumber(); - auto resultType = castOp->getResult(0).getType().cast(); + auto resultType = + llvm::cast(castOp->getResult(0).getType()); // Replace the `outs` for the result with a `tensor.cast`. This cast is now // going from a more dynamic shape to a less dynamic shape. If the producer // for this cast, i.e. producer of the out operand, is also an operation @@ -1975,7 +1988,7 @@ if (linalgOp.isScalar(&opOperand)) continue; Value src = opOperand.get(); - auto sourceType = src.getType().cast(); + auto sourceType = llvm::cast(src.getType()); auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand); // Get the `sourceShape` of the `sourceType`. If the operand is a result of @@ -1986,7 +1999,8 @@ if (parentOp) { if (auto castOp = dyn_cast(parentOp)) { Value castSource = castOp.getSource(); - auto castSourceType = castSource.getType().dyn_cast(); + auto castSourceType = + llvm::dyn_cast(castSource.getType()); if (castSourceType && castSourceType.hasStaticShape()) sourceShape = castSourceType.getShape(); } @@ -2017,7 +2031,7 @@ newOperands.push_back(src); if (linalgOp.isScalar(opOperand)) return; - auto sourceType = src.getType().cast(); + auto sourceType = llvm::cast(src.getType()); Type resultType = sourceType; if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) { resultTypes.push_back(resultType); diff --git a/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp @@ -37,7 +37,7 @@ int64_t flatDimCtr = 0; for (Value operand : linalgOp->getOperands()) { assert(flatDimPos >= flatDimCtr && "invalid pos"); - auto shapedType = operand.getType().cast(); + auto shapedType = llvm::cast(operand.getType()); if (flatDimPos < flatDimCtr + shapedType.getRank()) { cstr.bound(value) < cstr.getExpr(operand, flatDimPos - flatDimCtr); break; diff --git a/mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp b/mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp --- a/mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp +++ b/mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp @@ -28,7 +28,7 @@ using OpAsmDialectInterface::OpAsmDialectInterface; AliasResult getAlias(Attribute attr, raw_ostream &os) const override { - if (attr.isa()) { + if (llvm::isa(attr)) { os << "extern"; return AliasResult::OverridableAlias; } diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp @@ -22,7 +22,7 @@ //===----------------------------------------------------------------------===// static bool isSupportedElementType(Type type) { - return type.isa() || + return llvm::isa(type) || OpBuilder(type.getContext()).getZeroAttr(type); } diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -90,7 +90,7 @@ for (OpOperand &operand : op->getOpOperands()) { auto cast = operand.get().getDefiningOp(); if (cast && operand.get() != inner && - !cast.getOperand().getType().isa()) { + !llvm::isa(cast.getOperand().getType())) { operand.set(cast.getOperand()); folded = true; } @@ -101,16 +101,16 @@ /// Return an unranked/ranked tensor type for the given unranked/ranked memref /// type. Type mlir::memref::getTensorTypeFromMemRefType(Type type) { - if (auto memref = type.dyn_cast()) + if (auto memref = llvm::dyn_cast(type)) return RankedTensorType::get(memref.getShape(), memref.getElementType()); - if (auto memref = type.dyn_cast()) + if (auto memref = llvm::dyn_cast(type)) return UnrankedTensorType::get(memref.getElementType()); return NoneType::get(type.getContext()); } SmallVector memref::getMixedSizes(OpBuilder &builder, Location loc, Value value) { - auto memrefType = value.getType().cast(); + auto memrefType = llvm::cast(value.getType()); SmallVector result; for (int64_t i = 0; i < memrefType.getRank(); ++i) { if (memrefType.isDynamicDim(i)) { @@ -180,7 +180,7 @@ // values, hence we recreate the attribute even when it is already static // to make sure the type is consistent. ofr = builder.getIndexAttr( - ofr.get().cast().getInt()); + llvm::cast(ofr.get()).getInt()); continue; } std::optional maybeConstant = @@ -241,7 +241,7 @@ static LogicalResult verifyAllocLikeOp(AllocLikeOp op) { static_assert(llvm::is_one_of::value, "applies to only alloc or alloca"); - auto memRefType = op.getResult().getType().template dyn_cast(); + auto memRefType = llvm::dyn_cast(op.getResult().getType()); if (!memRefType) return op.emitOpError("result must be a memref"); @@ -378,7 +378,7 @@ //===----------------------------------------------------------------------===// LogicalResult ReallocOp::verify() { - auto sourceType = getOperand(0).getType().cast(); + auto sourceType = llvm::cast(getOperand(0).getType()); MemRefType resultType = getType(); // The source memref should have identity layout (or none). @@ -691,8 +691,9 @@ /// consumer %0 ... : memref(16 * i + j)>> /// ``` bool CastOp::canFoldIntoConsumerOp(CastOp castOp) { - MemRefType sourceType = castOp.getSource().getType().dyn_cast(); - MemRefType resultType = castOp.getType().dyn_cast(); + MemRefType sourceType = + llvm::dyn_cast(castOp.getSource().getType()); + MemRefType resultType = llvm::dyn_cast(castOp.getType()); // Requires ranked MemRefType. if (!sourceType || !resultType) @@ -743,11 +744,11 @@ if (inputs.size() != 1 || outputs.size() != 1) return false; Type a = inputs.front(), b = outputs.front(); - auto aT = a.dyn_cast(); - auto bT = b.dyn_cast(); + auto aT = llvm::dyn_cast(a); + auto bT = llvm::dyn_cast(b); - auto uaT = a.dyn_cast(); - auto ubT = b.dyn_cast(); + auto uaT = llvm::dyn_cast(a); + auto ubT = llvm::dyn_cast(b); if (aT && bT) { if (aT.getElementType() != bT.getElementType()) @@ -831,8 +832,8 @@ // Check source. if (auto castOp = copyOp.getSource().getDefiningOp()) { - auto fromType = castOp.getSource().getType().dyn_cast(); - auto toType = castOp.getSource().getType().dyn_cast(); + auto fromType = llvm::dyn_cast(castOp.getSource().getType()); + auto toType = llvm::dyn_cast(castOp.getSource().getType()); if (fromType && toType) { if (fromType.getShape() == toType.getShape() && @@ -847,8 +848,8 @@ // Check target. if (auto castOp = copyOp.getTarget().getDefiningOp()) { - auto fromType = castOp.getSource().getType().dyn_cast(); - auto toType = castOp.getSource().getType().dyn_cast(); + auto fromType = llvm::dyn_cast(castOp.getSource().getType()); + auto toType = llvm::dyn_cast(castOp.getSource().getType()); if (fromType && toType) { if (fromType.getShape() == toType.getShape() && @@ -970,7 +971,7 @@ for (const auto &dim : llvm::enumerate(sizes)) if (auto attr = dim.value().dyn_cast()) - if (attr.cast().getInt() == 1) + if (llvm::cast(attr).getInt() == 1) unusedDims.set(dim.index()); // Early exit for the case where the number of unused dims matches the number @@ -1046,7 +1047,7 @@ return {}; // Folding for unranked types (UnrankedMemRefType) is not supported. - auto memrefType = getSource().getType().dyn_cast(); + auto memrefType = llvm::dyn_cast(getSource().getType()); if (!memrefType) return {}; @@ -1256,7 +1257,7 @@ // Check types of operands. The order of these calls is important: the later // calls rely on some type properties to compute the operand position. // 1. Source memref. - if (!getSrcMemRef().getType().isa()) + if (!llvm::isa(getSrcMemRef().getType())) return emitOpError("expected source to be of memref type"); if (numOperands < getSrcMemRefRank() + 4) return emitOpError() << "expected at least " << getSrcMemRefRank() + 4 @@ -1267,7 +1268,7 @@ return emitOpError("expected source indices to be of index type"); // 2. Destination memref. - if (!getDstMemRef().getType().isa()) + if (!llvm::isa(getDstMemRef().getType())) return emitOpError("expected destination to be of memref type"); unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4; if (numOperands < numExpectedOperands) @@ -1283,7 +1284,7 @@ return emitOpError("expected num elements to be of index type"); // 4. Tag memref. - if (!getTagMemRef().getType().isa()) + if (!llvm::isa(getTagMemRef().getType())) return emitOpError("expected tag to be of memref type"); numExpectedOperands += getTagMemRefRank(); if (numOperands < numExpectedOperands) @@ -1359,7 +1360,8 @@ SmallVectorImpl &inferredReturnTypes) { ExtractStridedMetadataOpAdaptor extractAdaptor(operands, attributes, properties); - auto sourceType = extractAdaptor.getSource().getType().dyn_cast(); + auto sourceType = + llvm::dyn_cast(extractAdaptor.getSource().getType()); if (!sourceType) return failure(); @@ -1409,8 +1411,7 @@ "The constified value should be either unchanged (i.e., == result) " "or a constant"); Value constantVal = rewriter.create( - loc, maybeConstant.template get() - .template cast() + loc, llvm::cast(maybeConstant.template get()) .getInt()); for (Operation *op : llvm::make_early_inc_range(result.getUsers())) { // updateRootInplace: lambda cannot capture structured bindings in C++17 @@ -1470,7 +1471,7 @@ result.addOperands(memref); result.addOperands(ivs); - if (auto memrefType = memref.getType().dyn_cast()) { + if (auto memrefType = llvm::dyn_cast(memref.getType())) { Type elementType = memrefType.getElementType(); result.addTypes(elementType); @@ -1519,7 +1520,7 @@ if (parser.parseRegion(*body, {}) || parser.parseOptionalAttrDict(result.attributes)) return failure(); - result.types.push_back(memrefType.cast().getElementType()); + result.types.push_back(llvm::cast(memrefType).getElementType()); return success(); } @@ -1567,7 +1568,7 @@ if (parser.parseType(type)) return failure(); - auto memrefType = type.dyn_cast(); + auto memrefType = llvm::dyn_cast(type); if (!memrefType || !memrefType.hasStaticShape()) return parser.emitError(parser.getNameLoc()) << "type should be static shaped memref, but got " << type; @@ -1584,14 +1585,14 @@ Type tensorType = getTensorTypeFromMemRefType(memrefType); if (parser.parseAttribute(initialValue, tensorType)) return failure(); - if (!initialValue.isa()) + if (!llvm::isa(initialValue)) return parser.emitError(parser.getNameLoc()) << "initial value should be a unit or elements attribute"; return success(); } LogicalResult GlobalOp::verify() { - auto memrefType = getType().dyn_cast(); + auto memrefType = llvm::dyn_cast(getType()); if (!memrefType || !memrefType.hasStaticShape()) return emitOpError("type should be static shaped memref, but got ") << getType(); @@ -1600,14 +1601,14 @@ // an elements attribute. if (getInitialValue().has_value()) { Attribute initValue = getInitialValue().value(); - if (!initValue.isa() && !initValue.isa()) + if (!llvm::isa(initValue) && !llvm::isa(initValue)) return emitOpError("initial value should be a unit or elements " "attribute, but got ") << initValue; // Check that the type of the initial value is compatible with the type of // the global variable. - if (auto elementsAttr = initValue.dyn_cast()) { + if (auto elementsAttr = llvm::dyn_cast(initValue)) { Type initType = elementsAttr.getType(); Type tensorType = getTensorTypeFromMemRefType(memrefType); if (initType != tensorType) @@ -1631,7 +1632,7 @@ ElementsAttr GlobalOp::getConstantInitValue() { auto initVal = getInitialValue(); if (getConstant() && initVal.has_value()) - return initVal.value().cast(); + return llvm::cast(initVal.value()); return {}; } @@ -1687,11 +1688,11 @@ if (inputs.size() != 1 || outputs.size() != 1) return false; Type a = inputs.front(), b = outputs.front(); - auto aT = a.dyn_cast(); - auto bT = b.dyn_cast(); + auto aT = llvm::dyn_cast(a); + auto bT = llvm::dyn_cast(b); - auto uaT = a.dyn_cast(); - auto ubT = b.dyn_cast(); + auto uaT = llvm::dyn_cast(a); + auto ubT = llvm::dyn_cast(b); if (aT && bT) { if (aT.getElementType() != bT.getElementType()) @@ -1794,7 +1795,7 @@ OpFoldResult RankOp::fold(FoldAdaptor adaptor) { // Constant fold rank when the rank of the operand is known. auto type = getOperand().getType(); - auto shapedType = type.dyn_cast(); + auto shapedType = llvm::dyn_cast(type); if (shapedType && shapedType.hasRank()) return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank()); return IntegerAttr(); @@ -1861,8 +1862,8 @@ // completed automatically, like we have for subview and extract_slice. LogicalResult ReinterpretCastOp::verify() { // The source and result memrefs should be in the same memory space. - auto srcType = getSource().getType().cast(); - auto resultType = getType().cast(); + auto srcType = llvm::cast(getSource().getType()); + auto resultType = llvm::cast(getType()); if (srcType.getMemorySpace() != resultType.getMemorySpace()) return emitError("different memory spaces specified for source type ") << srcType << " and result memref type " << resultType; @@ -2250,7 +2251,7 @@ ArrayRef resultShape, Value src, ArrayRef reassociation) { // Only ranked memref source values are supported. - auto srcType = src.getType().cast(); + auto srcType = llvm::cast(src.getType()); FailureOr resultType = ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation); // Failure of this assertion usually indicates a problem with the source @@ -2406,7 +2407,7 @@ void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src, ArrayRef reassociation, ArrayRef attrs) { - auto srcType = src.getType().cast(); + auto srcType = llvm::cast(src.getType()); MemRefType resultType = CollapseShapeOp::computeCollapsedType(srcType, reassociation); build(b, result, resultType, src, attrs); @@ -2473,7 +2474,7 @@ return failure(); Type newResultType = CollapseShapeOp::computeCollapsedType( - cast.getOperand().getType().cast(), + llvm::cast(cast.getOperand().getType()), op.getReassociationIndices()); if (newResultType == op.getResultType()) { @@ -2518,18 +2519,20 @@ Type operandType = getSource().getType(); Type resultType = getResult().getType(); - Type operandElementType = operandType.cast().getElementType(); - Type resultElementType = resultType.cast().getElementType(); + Type operandElementType = + llvm::cast(operandType).getElementType(); + Type resultElementType = llvm::cast(resultType).getElementType(); if (operandElementType != resultElementType) return emitOpError("element types of source and destination memref " "types should be the same"); - if (auto operandMemRefType = operandType.dyn_cast()) + if (auto operandMemRefType = llvm::dyn_cast(operandType)) if (!operandMemRefType.getLayout().isIdentity()) return emitOpError("source memref type should have identity affine map"); - int64_t shapeSize = getShape().getType().cast().getDimSize(0); - auto resultMemRefType = resultType.dyn_cast(); + int64_t shapeSize = + llvm::cast(getShape().getType()).getDimSize(0); + auto resultMemRefType = llvm::dyn_cast(resultType); if (resultMemRefType) { if (!resultMemRefType.getLayout().isIdentity()) return emitOpError("result memref type should have identity affine map"); @@ -2634,9 +2637,8 @@ ArrayRef offsets, ArrayRef sizes, ArrayRef strides) { - auto inferredType = - inferResultType(sourceRankedTensorType, offsets, sizes, strides) - .cast(); + auto inferredType = llvm::cast( + inferResultType(sourceRankedTensorType, offsets, sizes, strides)); assert(inferredType.getRank() >= static_cast(resultShape.size()) && "expected "); if (inferredType.getRank() == static_cast(resultShape.size())) @@ -2648,7 +2650,7 @@ assert(dimsToProject.has_value() && "invalid rank reduction"); // Compute the layout and result type. - auto inferredLayout = inferredType.getLayout().cast(); + auto inferredLayout = llvm::cast(inferredType.getLayout()); SmallVector rankReducedStrides; rankReducedStrides.reserve(resultShape.size()); for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) { @@ -2690,12 +2692,11 @@ dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes); dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); - auto sourceMemRefType = source.getType().cast(); + auto sourceMemRefType = llvm::cast(source.getType()); // Structuring implementation this way avoids duplication between builders. if (!resultType) { - resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets, - staticSizes, staticStrides) - .cast(); + resultType = llvm::cast(SubViewOp::inferResultType( + sourceMemRefType, staticOffsets, staticSizes, staticStrides)); } build(b, result, resultType, source, dynamicOffsets, dynamicSizes, dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets), @@ -2824,7 +2825,7 @@ template static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, OpTy op, Type expectedType) { - auto memrefType = expectedType.cast(); + auto memrefType = llvm::cast(expectedType); switch (result) { case SliceVerificationResult::Success: return success(); @@ -2867,7 +2868,7 @@ auto expectedType = SubViewOp::inferResultType( baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides()); - auto result = isRankReducedMemRefType(expectedType.cast(), + auto result = isRankReducedMemRefType(llvm::cast(expectedType), subViewType, getMixedSizes()); return produceSubViewErrorMsg(result, *this, expectedType); } @@ -2917,9 +2918,8 @@ MemRefType currentResultType, MemRefType currentSourceType, MemRefType sourceType, ArrayRef mixedOffsets, ArrayRef mixedSizes, ArrayRef mixedStrides) { - auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets, - mixedSizes, mixedStrides) - .cast(); + auto nonRankReducedType = llvm::cast(SubViewOp::inferResultType( + sourceType, mixedOffsets, mixedSizes, mixedStrides)); std::optional unusedDims = computeMemRefRankReductionMask(currentSourceType, currentResultType, mixedSizes); @@ -2927,7 +2927,7 @@ if (!unusedDims) return nullptr; - auto layout = nonRankReducedType.getLayout().cast(); + auto layout = llvm::cast(nonRankReducedType.getLayout()); SmallVector shape, strides; unsigned numDimsAfterReduction = nonRankReducedType.getRank() - unusedDims->count(); @@ -2962,14 +2962,14 @@ Value mlir::memref::createCanonicalRankReducingSubViewOp( OpBuilder &b, Location loc, Value memref, ArrayRef targetShape) { - auto memrefType = memref.getType().cast(); + auto memrefType = llvm::cast(memref.getType()); unsigned rank = memrefType.getRank(); SmallVector offsets(rank, b.getIndexAttr(0)); SmallVector sizes = getMixedSizes(b, loc, memref); SmallVector strides(rank, b.getIndexAttr(1)); - auto targetType = SubViewOp::inferRankReducedResultType( - targetShape, memrefType, offsets, sizes, strides) - .cast(); + auto targetType = + llvm::cast(SubViewOp::inferRankReducedResultType( + targetShape, memrefType, offsets, sizes, strides)); return b.createOrFold(loc, targetType, memref, offsets, sizes, strides); } @@ -2977,7 +2977,7 @@ FailureOr SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value, ArrayRef desiredShape) { - auto sourceMemrefType = value.getType().dyn_cast(); + auto sourceMemrefType = llvm::dyn_cast(value.getType()); assert(sourceMemrefType && "not a ranked memref type"); auto sourceShape = sourceMemrefType.getShape(); if (sourceShape.equals(desiredShape)) @@ -3069,7 +3069,7 @@ // if the operation is rank-reducing. auto resultType = getCanonicalSubViewResultType( subViewOp.getType(), subViewOp.getSourceType(), - castOp.getSource().getType().cast(), + llvm::cast(castOp.getSource().getType()), subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(), subViewOp.getMixedStrides()); if (!resultType) @@ -3134,8 +3134,8 @@ } OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) { - auto resultShapedType = getResult().getType().cast(); - auto sourceShapedType = getSource().getType().cast(); + auto resultShapedType = llvm::cast(getResult().getType()); + auto sourceShapedType = llvm::cast(getSource().getType()); if (resultShapedType.hasStaticShape() && resultShapedType == sourceShapedType) { @@ -3201,7 +3201,7 @@ auto permutationMap = permutation.getValue(); assert(permutationMap); - auto memRefType = in.getType().cast(); + auto memRefType = llvm::cast(in.getType()); // Compute result type. MemRefType resultType = inferTransposeResultType(memRefType, permutationMap); @@ -3239,8 +3239,8 @@ if (getPermutation().getNumDims() != getIn().getType().getRank()) return emitOpError("expected a permutation map of same rank as the input"); - auto srcType = getIn().getType().cast(); - auto dstType = getType().cast(); + auto srcType = llvm::cast(getIn().getType()); + auto dstType = llvm::cast(getType()); auto transposedType = inferTransposeResultType(srcType, getPermutation()); if (dstType != transposedType) return emitOpError("output type ") @@ -3264,7 +3264,7 @@ } LogicalResult ViewOp::verify() { - auto baseType = getOperand(0).getType().cast(); + auto baseType = llvm::cast(getOperand(0).getType()); auto viewType = getType(); // The base memref should have identity layout map (or none). @@ -3401,7 +3401,7 @@ case arith::AtomicRMWKind::maxf: case arith::AtomicRMWKind::minf: case arith::AtomicRMWKind::mulf: - if (!getValue().getType().isa()) + if (!llvm::isa(getValue().getType())) return emitOpError() << "with kind '" << arith::stringifyAtomicRMWKind(getKind()) << "' expects a floating-point type"; @@ -3414,7 +3414,7 @@ case arith::AtomicRMWKind::muli: case arith::AtomicRMWKind::ori: case arith::AtomicRMWKind::andi: - if (!getValue().getType().isa()) + if (!llvm::isa(getValue().getType())) return emitOpError() << "with kind '" << arith::stringifyAtomicRMWKind(getKind()) << "' expects an integer type"; diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp @@ -37,8 +37,8 @@ auto castOp = cast(op); assert(value == castOp.getResult() && "invalid value"); - if (castOp.getResult().getType().isa() && - castOp.getSource().getType().isa()) { + if (llvm::isa(castOp.getResult().getType()) && + llvm::isa(castOp.getSource().getType())) { cstr.bound(value)[dim] == cstr.getExpr(castOp.getSource(), dim); } } @@ -79,7 +79,7 @@ auto rankOp = cast(op); assert(value == rankOp.getResult() && "invalid value"); - auto memrefType = rankOp.getMemref().getType().dyn_cast(); + auto memrefType = llvm::dyn_cast(rankOp.getMemref().getType()); if (!memrefType) return; cstr.bound(value) == memrefType.getRank(); diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp --- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp +++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp @@ -38,9 +38,9 @@ Attribute memorySpace = type.getMemorySpace(); if (!memorySpace) return false; - if (auto intAttr = memorySpace.dyn_cast()) + if (auto intAttr = llvm::dyn_cast(memorySpace)) return intAttr.getInt() == NVGPUDialect::kSharedMemoryAddressSpace; - if (auto gpuAttr = memorySpace.dyn_cast()) + if (auto gpuAttr = llvm::dyn_cast(memorySpace)) return gpuAttr.getValue() == gpu::AddressSpace::Workgroup; return false; } @@ -61,8 +61,8 @@ } LogicalResult DeviceAsyncCopyOp::verify() { - auto srcMemref = getSrc().getType().cast(); - auto dstMemref = getDst().getType().cast(); + auto srcMemref = llvm::cast(getSrc().getType()); + auto dstMemref = llvm::cast(getDst().getType()); if (!isLastMemrefDimUnitStride(srcMemref)) return emitError("source memref most minor dim must have unit stride"); @@ -246,10 +246,10 @@ LogicalResult LdMatrixOp::verify() { // ldmatrix reads data from source in shared memory - auto srcMemref = getSrcMemref().getType().cast(); + auto srcMemref = llvm::cast(getSrcMemref().getType()); // ldmatrix writes data to result/destination in vector registers - auto resVector = getRes().getType().cast(); + auto resVector = llvm::cast(getRes().getType()); // vector register shape, element type, and bitwidth ArrayRef resShape = resVector.getShape(); diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -41,7 +41,7 @@ struct PointerLikeModel : public PointerLikeType::ExternalModel, T> { Type getElementType(Type pointer) const { - return pointer.cast().getElementType(); + return llvm::cast(pointer).getElementType(); } }; @@ -231,7 +231,7 @@ // Check if all alignment values are positive - OpenMP 4.5 -> 2.8.1 section for (unsigned i = 0; i < (*alignmentValues).size(); ++i) { - if (auto intAttr = (*alignmentValues)[i].dyn_cast()) { + if (auto intAttr = llvm::dyn_cast((*alignmentValues)[i])) { if (intAttr.getValue().sle(0)) return op->emitOpError() << "alignment should be greater than 0"; } else { @@ -463,7 +463,7 @@ return op->emitOpError() << "accumulator variable used more than once"; Type varType = accum.getType(); - auto symbolRef = std::get<1>(args).cast(); + auto symbolRef = llvm::cast(std::get<1>(args)); auto decl = SymbolTable::lookupNearestSymbolFrom(op, symbolRef); if (!decl) @@ -521,7 +521,8 @@ if (i != 0) p << ", "; p << stringifyClauseTaskDepend( - (*depends)[i].cast().getValue()) + llvm::cast((*depends)[i]) + .getValue()) << " -> " << dependVars[i] << " : " << dependTypes[i]; } } @@ -723,8 +724,8 @@ Value mapOp = map_operands[i]; Attribute mapTypeOp = map_types[i]; - assert(mapTypeOp.isa()); - mapTypeBits = mapTypeOp.cast().getInt(); + assert(llvm::isa(mapTypeOp)); + mapTypeBits = llvm::cast(mapTypeOp).getInt(); bool always = bitAnd(mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS); @@ -1018,8 +1019,8 @@ atomicReductionEntryBlock.getArgumentTypes()[1]) return emitOpError() << "expects atomic reduction region with two " "arguments of the same type"; - auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0] - .dyn_cast(); + auto ptrType = llvm::dyn_cast( + atomicReductionEntryBlock.getArgumentTypes()[0]); if (!ptrType || (ptrType.getElementType() && ptrType.getElementType() != getType())) return emitOpError() << "expects atomic reduction region arguments to " @@ -1210,7 +1211,7 @@ } } Type elementType = - getAddress().getType().cast().getElementType(); + llvm::cast(getAddress().getType()).getElementType(); if (elementType && elementType != getValue().getType()) return emitError("address must dereference to value type"); return verifySynchronizationHint(*this, getHintVal()); @@ -1261,7 +1262,8 @@ if (getRegion().getNumArguments() != 1) return emitError("the region must accept exactly one argument"); - Type elementType = getX().getType().cast().getElementType(); + Type elementType = + llvm::cast(getX().getType()).getElementType(); if (elementType && elementType != getRegion().getArgument(0).getType()) { return emitError("the type of the operand must be a pointer type whose " "element type is the same as that of the region argument"); diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp --- a/mlir/lib/Dialect/PDL/IR/PDL.cpp +++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp @@ -465,7 +465,7 @@ } LogicalResult ResultsOp::verify() { - if (!getIndex() && getType().isa()) { + if (!getIndex() && llvm::isa(getType())) { return emitOpError() << "expected `pdl.range` result type when " "no index is specified, but got: " << getType(); diff --git a/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp b/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp --- a/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp +++ b/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp @@ -60,7 +60,7 @@ } Type pdl::getRangeElementTypeOrSelf(Type type) { - if (auto rangeType = type.dyn_cast()) + if (auto rangeType = llvm::dyn_cast(type)) return rangeType.getElementType(); return type; } @@ -78,7 +78,7 @@ if (!elementType || parser.parseGreater()) return Type(); - if (elementType.isa()) { + if (llvm::isa(elementType)) { parser.emitError(elementLoc) << "element of pdl.range cannot be another range, but got" << elementType; @@ -95,7 +95,7 @@ LogicalResult RangeType::verify(function_ref emitError, Type elementType) { - if (!elementType.isa() || elementType.isa()) { + if (!llvm::isa(elementType) || llvm::isa(elementType)) { return emitError() << "expected element of pdl.range to be one of [!pdl.attribute, " "!pdl.operation, !pdl.type, !pdl.value], but got " diff --git a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp --- a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp +++ b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp @@ -145,7 +145,7 @@ if (initLoop) { // Create the block and the loop variable. // FIXME: Allow passing in a proper location for the loop variable. - auto rangeType = range.getType().cast(); + auto rangeType = llvm::cast(range.getType()); state.regions.front()->emplaceBlock(); state.regions.front()->addArgument(rangeType.getElementType(), state.location); @@ -238,7 +238,8 @@ /// Given the result type of a `GetValueTypeOp`, return the expected input type. static Type getGetValueTypeOpValueType(Type type) { Type valueTy = pdl::ValueType::get(type.getContext()); - return type.isa() ? pdl::RangeType::get(valueTy) : valueTy; + return llvm::isa(type) ? pdl::RangeType::get(valueTy) + : valueTy; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp --- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp @@ -35,7 +35,7 @@ // Verify that the storage type is integral. // This restriction may be lifted at some point in favor of using bf16 // or f16 as exact representations on hardware where that is advantageous. - auto intStorageType = storageType.dyn_cast(); + auto intStorageType = llvm::dyn_cast(storageType); if (!intStorageType) return emitError() << "storage type must be integral"; unsigned integralWidth = intStorageType.getWidth(); @@ -83,8 +83,8 @@ } bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) { - if (candidateExpressedType.isa()) { - return candidateExpressedType.cast().getElementType() == + if (llvm::isa(candidateExpressedType)) { + return llvm::cast(candidateExpressedType).getElementType() == getExpressedType(); } return candidateExpressedType == getExpressedType(); @@ -92,12 +92,12 @@ QuantizedType QuantizedType::getQuantizedElementType(Type primitiveOrContainerType) { - if (primitiveOrContainerType.isa()) { + if (llvm::isa(primitiveOrContainerType)) { Type elementType = - primitiveOrContainerType.cast().getElementType(); - return elementType.dyn_cast(); + llvm::cast(primitiveOrContainerType).getElementType(); + return llvm::dyn_cast(elementType); } - return primitiveOrContainerType.dyn_cast(); + return llvm::dyn_cast(primitiveOrContainerType); } Type QuantizedType::castFromStorageType(Type candidateType) { @@ -105,18 +105,19 @@ // i.e. i32 -> quant<"uniform[i8:f32]{1.0}"> return *this; } - if (candidateType.isa()) { + if (llvm::isa(candidateType)) { // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> return RankedTensorType::get( - candidateType.cast().getShape(), getStorageType()); + llvm::cast(candidateType).getShape(), + getStorageType()); } - if (candidateType.isa()) { + if (llvm::isa(candidateType)) { // i.e. tensor -> tensor> return UnrankedTensorType::get(getStorageType()); } - if (candidateType.isa()) { + if (llvm::isa(candidateType)) { // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> - return VectorType::get(candidateType.cast().getShape(), + return VectorType::get(llvm::cast(candidateType).getShape(), getStorageType()); } @@ -124,25 +125,25 @@ } Type QuantizedType::castToStorageType(Type quantizedType) { - if (quantizedType.isa()) { + if (llvm::isa(quantizedType)) { // i.e. quant<"uniform[i8:f32]{1.0}"> -> i8 - return quantizedType.cast().getStorageType(); + return llvm::cast(quantizedType).getStorageType(); } - if (quantizedType.isa()) { + if (llvm::isa(quantizedType)) { // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> - ShapedType sType = quantizedType.cast(); - if (!sType.getElementType().isa()) { + ShapedType sType = llvm::cast(quantizedType); + if (!llvm::isa(sType.getElementType())) { return nullptr; } Type storageType = - sType.getElementType().cast().getStorageType(); - if (quantizedType.isa()) { + llvm::cast(sType.getElementType()).getStorageType(); + if (llvm::isa(quantizedType)) { return RankedTensorType::get(sType.getShape(), storageType); } - if (quantizedType.isa()) { + if (llvm::isa(quantizedType)) { return UnrankedTensorType::get(storageType); } - if (quantizedType.isa()) { + if (llvm::isa(quantizedType)) { return VectorType::get(sType.getShape(), storageType); } } @@ -155,21 +156,21 @@ // i.e. f32 -> quant<"uniform[i8:f32]{1.0}"> return *this; } - if (candidateType.isa()) { - ShapedType candidateShapedType = candidateType.cast(); + if (llvm::isa(candidateType)) { + ShapedType candidateShapedType = llvm::cast(candidateType); if (candidateShapedType.getElementType() != getExpressedType()) { return nullptr; } - if (candidateType.isa()) { + if (llvm::isa(candidateType)) { // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> return RankedTensorType::get(candidateShapedType.getShape(), *this); } - if (candidateType.isa()) { + if (llvm::isa(candidateType)) { // i.e. tensor -> tensor> return UnrankedTensorType::get(*this); } - if (candidateType.isa()) { + if (llvm::isa(candidateType)) { // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> return VectorType::get(candidateShapedType.getShape(), *this); } @@ -179,25 +180,25 @@ } Type QuantizedType::castToExpressedType(Type quantizedType) { - if (quantizedType.isa()) { + if (llvm::isa(quantizedType)) { // i.e. quant<"uniform[i8:f32]{1.0}"> -> f32 - return quantizedType.cast().getExpressedType(); + return llvm::cast(quantizedType).getExpressedType(); } - if (quantizedType.isa()) { + if (llvm::isa(quantizedType)) { // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> - ShapedType sType = quantizedType.cast(); - if (!sType.getElementType().isa()) { + ShapedType sType = llvm::cast(quantizedType); + if (!llvm::isa(sType.getElementType())) { return nullptr; } Type expressedType = - sType.getElementType().cast().getExpressedType(); - if (quantizedType.isa()) { + llvm::cast(sType.getElementType()).getExpressedType(); + if (llvm::isa(quantizedType)) { return RankedTensorType::get(sType.getShape(), expressedType); } - if (quantizedType.isa()) { + if (llvm::isa(quantizedType)) { return UnrankedTensorType::get(expressedType); } - if (quantizedType.isa()) { + if (llvm::isa(quantizedType)) { return VectorType::get(sType.getShape(), expressedType); } } @@ -243,7 +244,7 @@ // Verify that the expressed type is floating point. // If this restriction is ever eliminated, the parser/printer must be // extended. - if (expressedType && !expressedType.isa()) + if (expressedType && !llvm::isa(expressedType)) return emitError() << "expressed type must be floating point"; return success(); @@ -284,7 +285,7 @@ // Verify that the expressed type is floating point. // If this restriction is ever eliminated, the parser/printer must be // extended. - if (!expressedType.isa()) + if (!llvm::isa(expressedType)) return emitError() << "expressed type must be floating point"; // Verify scale. @@ -338,7 +339,7 @@ // Verify that the expressed type is floating point. // If this restriction is ever eliminated, the parser/printer must be // extended. - if (!expressedType.isa()) + if (!llvm::isa(expressedType)) return emitError() << "expressed type must be floating point"; // Ensure that the number of scales and zeroPoints match. @@ -385,7 +386,7 @@ // Verify that the expressed type is floating point. // If this restriction is ever eliminated, the parser/printer must be // extended. - if (!expressedType.isa()) + if (!llvm::isa(expressedType)) return emitError() << "expressed type must be floating point"; if (max <= min) return emitError() << "illegal min and max: (" << min << ":" << max << ")"; diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp --- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp +++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp @@ -420,13 +420,13 @@ /// Print a type registered to this dialect. void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const { - if (auto anyType = type.dyn_cast()) + if (auto anyType = llvm::dyn_cast(type)) printAnyQuantizedType(anyType, os); - else if (auto uniformType = type.dyn_cast()) + else if (auto uniformType = llvm::dyn_cast(type)) printUniformQuantizedType(uniformType, os); - else if (auto perAxisType = type.dyn_cast()) + else if (auto perAxisType = llvm::dyn_cast(type)) printUniformQuantizedPerAxisType(perAxisType, os); - else if (auto calibratedType = type.dyn_cast()) + else if (auto calibratedType = llvm::dyn_cast(type)) printCalibratedQuantizedType(calibratedType, os); else llvm_unreachable("Unhandled quantized type"); diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -495,7 +495,7 @@ Region &ForOp::getLoopBody() { return getRegion(); } ForOp mlir::scf::getForInductionVarOwner(Value val) { - auto ivArg = val.dyn_cast(); + auto ivArg = llvm::dyn_cast(val); if (!ivArg) return ForOp(); assert(ivArg.getOwner() && "unlinked block argument"); @@ -576,7 +576,7 @@ }; Value srcVal = mapping.lookupOrDefault(src); - if (srcVal.getType().isa()) { + if (llvm::isa(srcVal.getType())) { results.push_back(rewriter.create( forallOp.getLoc(), dst.getType(), srcVal, mapping.lookupOrDefault(dst), @@ -890,7 +890,8 @@ replaceTensorCastForOpIterArg(PatternRewriter &rewriter, OpOperand &operand, Value replacement) { Type oldType = operand.get().getType(), newType = replacement.getType(); - assert(oldType.isa() && newType.isa() && + assert(llvm::isa(oldType) && + llvm::isa(newType) && "expected ranked tensor types"); // 1. Create new iter operands, exactly 1 is replaced. @@ -1074,7 +1075,7 @@ cast(forOp.getRegion().front().getTerminator()); Value yieldVal = yieldOp->getOperand(idx); auto tensorLoadOp = yieldVal.getDefiningOp(); - bool isTensor = bbArg.getType().isa(); + bool isTensor = llvm::isa(bbArg.getType()); bufferization::ToMemrefOp tensorToMemref; // Either bbArg has no use or it has a single buffer_cast use. @@ -1445,7 +1446,7 @@ } ForallOp mlir::scf::getForallOpThreadIndexOwner(Value val) { - auto tidxArg = val.dyn_cast(); + auto tidxArg = llvm::dyn_cast(val); if (!tidxArg) return ForallOp(); assert(tidxArg.getOwner() && "unlinked block argument"); @@ -1464,7 +1465,8 @@ if (!forallOp) return failure(); Value sharedOut = - forallOp.getTiedOpOperand(dimOp.getSource().cast())->get(); + forallOp.getTiedOpOperand(llvm::cast(dimOp.getSource())) + ->get(); rewriter.updateRootInPlace( dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); }); return success(); @@ -1744,7 +1746,7 @@ llvm::map_range(getYieldingOps(), [](Operation &op) { // Add new ops here as needed. auto insertSliceOp = cast(&op); - return insertSliceOp.getDest().cast(); + return llvm::cast(insertSliceOp.getDest()); })); } @@ -1964,7 +1966,7 @@ // Otherwise, the successor is dependent on the condition. bool condition; - if (auto condAttr = operands.front().dyn_cast_or_null()) { + if (auto condAttr = llvm::dyn_cast_or_null(operands.front())) { condition = condAttr.getValue().isOne(); } else { // If the condition isn't constant, both regions may be executed. @@ -2006,7 +2008,7 @@ void IfOp::getRegionInvocationBounds( ArrayRef operands, SmallVectorImpl &invocationBounds) { - if (auto cond = operands[0].dyn_cast_or_null()) { + if (auto cond = llvm::dyn_cast_or_null(operands[0])) { // If the condition is known, then one region is known to be executed once // and the other zero times. invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0); @@ -2542,7 +2544,7 @@ // come from the same scf.if. for (const auto &tup : llvm::enumerate(thenYield)) { if (tup.value().getDefiningOp() == nestedIf) { - auto nestedIdx = tup.value().cast().getResultNumber(); + auto nestedIdx = llvm::cast(tup.value()).getResultNumber(); if (nestedIf.elseYield().getOperand(nestedIdx) != elseYield[tup.index()]) { return failure(); @@ -2818,7 +2820,7 @@ Region &ParallelOp::getLoopBody() { return getRegion(); } ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) { - auto ivArg = val.dyn_cast(); + auto ivArg = llvm::dyn_cast(val); if (!ivArg) return ParallelOp(); assert(ivArg.getOwner() && "unlinked block argument"); @@ -3130,7 +3132,7 @@ // Try to narrow the successor to the condition region. assert(!operands.empty() && "expected at least one operand"); - auto cond = operands[0].dyn_cast_or_null(); + auto cond = llvm::dyn_cast_or_null(operands[0]); if (!cond || !cond.getValue()) regions.emplace_back(getResults()); if (!cond || cond.getValue()) @@ -3360,7 +3362,7 @@ // block argument or the initial value of i-th before block argument. If // the comparison results `true`, i-th before block argument is a loop // invariant. - auto yieldOpBlockArg = yieldOpArg.dyn_cast(); + auto yieldOpBlockArg = llvm::dyn_cast(yieldOpArg); if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) { Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()]; if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) { @@ -3392,7 +3394,7 @@ // before block argument or the initial value of i-th before block // argument. If the comparison results `true`, i-th before block // argument is a loop invariant. - auto yieldOpBlockArg = yieldOpArg.dyn_cast(); + auto yieldOpBlockArg = llvm::dyn_cast(yieldOpArg); if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) { Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()]; if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) { @@ -3960,7 +3962,7 @@ } // If a constant was not provided, all regions are possible successors. - auto operandValue = operands.front().dyn_cast_or_null(); + auto operandValue = llvm::dyn_cast_or_null(operands.front()); if (!operandValue) { for (Region &caseRegion : getCaseRegions()) successors.emplace_back(&caseRegion); @@ -3981,7 +3983,7 @@ void IndexSwitchOp::getRegionInvocationBounds( ArrayRef operands, SmallVectorImpl &bounds) { - auto operandValue = operands.front().dyn_cast_or_null(); + auto operandValue = llvm::dyn_cast_or_null(operands.front()); if (!operandValue) { // All regions are invoked at most once. bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1)); diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp @@ -27,10 +27,10 @@ ValueBoundsConstraintSet &cstr) { // `value` is an iter_arg or an OpResult. int64_t iterArgIdx; - if (auto iterArg = value.dyn_cast()) { + if (auto iterArg = llvm::dyn_cast(value)) { iterArgIdx = iterArg.getArgNumber() - forOp.getNumInductionVars(); } else { - iterArgIdx = value.cast().getResultNumber(); + iterArgIdx = llvm::cast(value).getResultNumber(); } // An EQ constraint can be added if the yielded value (dimension size) @@ -63,7 +63,7 @@ bound, boundOperands, BoundType::EQ, yieldedValue, dim, [&](Value v, std::optional d) { // Stop when reaching a block argument of the loop body. - if (auto bbArg = v.dyn_cast()) + if (auto bbArg = llvm::dyn_cast(v)) return bbArg.getOwner()->getParentOp() == forOp; // Stop when reaching a value that is defined outside of the loop. It // is impossible to reach an iter_arg from there. diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp @@ -145,18 +145,20 @@ } uint32_t spirv::InterfaceVarABIAttr::getBinding() { - return getImpl()->binding.cast().getInt(); + return llvm::cast(getImpl()->binding).getInt(); } uint32_t spirv::InterfaceVarABIAttr::getDescriptorSet() { - return getImpl()->descriptorSet.cast().getInt(); + return llvm::cast(getImpl()->descriptorSet).getInt(); } std::optional spirv::InterfaceVarABIAttr::getStorageClass() { if (getImpl()->storageClass) return static_cast( - getImpl()->storageClass.cast().getValue().getZExtValue()); + llvm::cast(getImpl()->storageClass) + .getValue() + .getZExtValue()); return std::nullopt; } @@ -170,7 +172,7 @@ return emitError() << "expected 32-bit integer for binding"; if (storageClass) { - if (auto storageClassAttr = storageClass.cast()) { + if (auto storageClassAttr = llvm::cast(storageClass)) { auto storageClassValue = spirv::symbolizeStorageClass(storageClassAttr.getInt()); if (!storageClassValue) @@ -219,14 +221,14 @@ spirv::Version spirv::VerCapExtAttr::getVersion() { return static_cast( - getImpl()->version.cast().getValue().getZExtValue()); + llvm::cast(getImpl()->version).getValue().getZExtValue()); } spirv::VerCapExtAttr::ext_iterator::ext_iterator(ArrayAttr::iterator it) : llvm::mapped_iterator( it, [](Attribute attr) { - return *symbolizeExtension(attr.cast().getValue()); + return *symbolizeExtension(llvm::cast(attr).getValue()); }) {} spirv::VerCapExtAttr::ext_range spirv::VerCapExtAttr::getExtensions() { @@ -235,7 +237,7 @@ } ArrayAttr spirv::VerCapExtAttr::getExtensionsAttr() { - return getImpl()->extensions.cast(); + return llvm::cast(getImpl()->extensions); } spirv::VerCapExtAttr::cap_iterator::cap_iterator(ArrayAttr::iterator it) @@ -243,7 +245,7 @@ spirv::Capability (*)(Attribute)>( it, [](Attribute attr) { return *symbolizeCapability( - attr.cast().getValue().getZExtValue()); + llvm::cast(attr).getValue().getZExtValue()); }) {} spirv::VerCapExtAttr::cap_range spirv::VerCapExtAttr::getCapabilities() { @@ -252,7 +254,7 @@ } ArrayAttr spirv::VerCapExtAttr::getCapabilitiesAttr() { - return getImpl()->capabilities.cast(); + return llvm::cast(getImpl()->capabilities); } LogicalResult @@ -263,7 +265,7 @@ return emitError() << "expected 32-bit integer for version"; if (!llvm::all_of(capabilities.getValue(), [](Attribute attr) { - if (auto intAttr = attr.dyn_cast()) + if (auto intAttr = llvm::dyn_cast(attr)) if (spirv::symbolizeCapability(intAttr.getValue().getZExtValue())) return true; return false; @@ -271,7 +273,7 @@ return emitError() << "unknown capability in capability list"; if (!llvm::all_of(extensions.getValue(), [](Attribute attr) { - if (auto strAttr = attr.dyn_cast()) + if (auto strAttr = llvm::dyn_cast(attr)) if (spirv::symbolizeExtension(strAttr.getValue())) return true; return false; @@ -297,7 +299,7 @@ StringRef spirv::TargetEnvAttr::getKindName() { return "target_env"; } spirv::VerCapExtAttr spirv::TargetEnvAttr::getTripleAttr() const { - return getImpl()->triple.cast(); + return llvm::cast(getImpl()->triple); } spirv::Version spirv::TargetEnvAttr::getVersion() const { @@ -337,7 +339,7 @@ } spirv::ResourceLimitsAttr spirv::TargetEnvAttr::getResourceLimits() const { - return getImpl()->limits.cast(); + return llvm::cast(getImpl()->limits); } //===----------------------------------------------------------------------===// @@ -628,7 +630,7 @@ [&](spirv::Capability cap) { os << spirv::stringifyCapability(cap); }); printer << "], ["; llvm::interleaveComma(triple.getExtensionsAttr(), os, [&](Attribute attr) { - os << attr.cast().getValue(); + os << llvm::cast(attr).getValue(); }); printer << "]>"; } @@ -669,11 +671,11 @@ if (succeeded(generatedAttributePrinter(attr, printer))) return; - if (auto targetEnv = attr.dyn_cast()) + if (auto targetEnv = llvm::dyn_cast(attr)) print(targetEnv, printer); - else if (auto vceAttr = attr.dyn_cast()) + else if (auto vceAttr = llvm::dyn_cast(attr)) print(vceAttr, printer); - else if (auto interfaceVarABIAttr = attr.dyn_cast()) + else if (auto interfaceVarABIAttr = llvm::dyn_cast(attr)) print(interfaceVarABIAttr, printer); else llvm_unreachable("unhandled SPIR-V attribute kind"); diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp @@ -33,9 +33,9 @@ if (!attr) return std::nullopt; - if (auto boolAttr = attr.dyn_cast()) + if (auto boolAttr = llvm::dyn_cast(attr)) return boolAttr.getValue(); - if (auto splatAttr = attr.dyn_cast()) + if (auto splatAttr = llvm::dyn_cast(attr)) if (splatAttr.getElementType().isInteger(1)) return splatAttr.getSplatValue(); return std::nullopt; @@ -52,12 +52,12 @@ if (indices.empty()) return composite; - if (auto vector = composite.dyn_cast()) { + if (auto vector = llvm::dyn_cast(composite)) { assert(indices.size() == 1 && "must have exactly one index for a vector"); return vector.getValues()[indices[0]]; } - if (auto array = composite.dyn_cast()) { + if (auto array = llvm::dyn_cast(composite)) { assert(!indices.empty() && "must have at least one index for an array"); return extractCompositeElement(array.getValue()[indices[0]], indices.drop_front()); @@ -149,7 +149,7 @@ if (auto constructOp = getComposite().getDefiningOp()) { - auto type = constructOp.getType().cast(); + auto type = llvm::cast(constructOp.getType()); if (getIndices().size() == 1 && constructOp.getConstituents().size() == type.getNumElements()) { auto i = getIndices().begin()->cast(); @@ -159,7 +159,7 @@ auto indexVector = llvm::to_vector<8>(llvm::map_range(getIndices(), [](Attribute attr) { - return static_cast(attr.cast().getInt()); + return static_cast(llvm::cast(attr).getInt()); })); return extractCompositeElement(adaptor.getComposite(), indexVector); } @@ -434,10 +434,9 @@ // "Before version 1.4, Result Type must be a pointer, scalar, or vector. // Starting with version 1.4, Result Type can additionally be a composite type // other than a vector." - bool isScalarOrVector = trueBrStoreOp.getValue() - .getType() - .cast() - .isScalarOrVector(); + bool isScalarOrVector = + llvm::cast(trueBrStoreOp.getValue().getType()) + .isScalarOrVector(); // Check that each `spirv.Store` uses the same pointer, memory access // attributes and a valid type of the value. diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -164,19 +164,19 @@ return type; // Check other allowed types - if (auto t = type.dyn_cast()) { + if (auto t = llvm::dyn_cast(type)) { if (type.isBF16()) { parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types"); return Type(); } - } else if (auto t = type.dyn_cast()) { + } else if (auto t = llvm::dyn_cast(type)) { if (!ScalarType::isValid(t)) { parser.emitError(typeLoc, "only 1/8/16/32/64-bit integer type allowed but found ") << type; return Type(); } - } else if (auto t = type.dyn_cast()) { + } else if (auto t = llvm::dyn_cast(type)) { if (t.getRank() != 1) { parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t; return Type(); @@ -203,7 +203,7 @@ if (parser.parseType(type)) return Type(); - if (auto t = type.dyn_cast()) { + if (auto t = llvm::dyn_cast(type)) { if (t.getRank() != 1) { parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t; return Type(); @@ -216,7 +216,7 @@ return Type(); } - if (!t.getElementType().isa()) { + if (!llvm::isa(t.getElementType())) { parser.emitError(typeLoc, "matrix columns' elements must be of " "Float type, got ") << t.getElementType(); @@ -239,7 +239,7 @@ if (parser.parseType(type)) return Type(); - if (!type.isa()) { + if (!llvm::isa(type)) { parser.emitError(typeLoc, "sampled image must be composed using image type, got ") << type; @@ -939,12 +939,12 @@ Attribute attr = attribute.getValue(); if (symbol == spirv::getEntryPointABIAttrName()) { - if (!attr.isa()) { + if (!llvm::isa(attr)) { return op->emitError("'") << symbol << "' attribute must be an entry point ABI attribute"; } } else if (symbol == spirv::getTargetEnvAttrName()) { - if (!attr.isa()) + if (!llvm::isa(attr)) return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr"; } else { return op->emitError("found unsupported '") @@ -965,7 +965,7 @@ return emitError(loc, "found unsupported '") << symbol << "' attribute on region argument"; - auto varABIAttr = attr.dyn_cast(); + auto varABIAttr = llvm::dyn_cast(attr); if (!varABIAttr) return emitError(loc, "'") << symbol << "' must be a spirv::InterfaceVarABIAttr"; diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -81,7 +81,7 @@ parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(type)) return failure(); - auto fnType = type.dyn_cast(); + auto fnType = llvm::dyn_cast(type); if (!fnType) { parser.emitError(loc, "expected function type"); return failure(); @@ -141,7 +141,7 @@ return failure(); } auto valueAttr = constOp.getValue(); - auto integerValueAttr = valueAttr.dyn_cast(); + auto integerValueAttr = llvm::dyn_cast(valueAttr); if (!integerValueAttr) { return failure(); } @@ -181,11 +181,11 @@ if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(), attrName, attr)) return failure(); - if (!attrVal.isa()) + if (!llvm::isa(attrVal)) return parser.emitError(loc, "expected ") << attrName << " attribute specified as string"; - auto attrOptional = - spirv::symbolizeEnum(attrVal.cast().getValue()); + auto attrOptional = spirv::symbolizeEnum( + llvm::cast(attrVal).getValue()); if (!attrOptional) return parser.emitError(loc, "invalid ") << attrName << " attribute specification: " << attrVal; @@ -430,23 +430,23 @@ Type resultType = op->getResult(0).getType(); // ODS checks that result type and operand type have the same shape. - if (auto vectorType = operandType.dyn_cast()) { + if (auto vectorType = llvm::dyn_cast(operandType)) { operandType = vectorType.getElementType(); - resultType = resultType.cast().getElementType(); + resultType = llvm::cast(resultType).getElementType(); } if (auto coopMatrixType = - operandType.dyn_cast()) { + llvm::dyn_cast(operandType)) { operandType = coopMatrixType.getElementType(); resultType = - resultType.cast().getElementType(); + llvm::cast(resultType).getElementType(); } if (auto jointMatrixType = - operandType.dyn_cast()) { + llvm::dyn_cast(operandType)) { operandType = jointMatrixType.getElementType(); resultType = - resultType.cast().getElementType(); + llvm::cast(resultType).getElementType(); } auto operandTypeBitWidth = operandType.getIntOrFloatBitWidth(); @@ -490,7 +490,7 @@ return success(); } - auto memAccess = memAccessAttr.template cast(); + auto memAccess = llvm::cast(memAccessAttr); if (!memAccess) { return memoryOp.emitOpError("invalid memory access specifier: ") @@ -534,7 +534,7 @@ return success(); } - auto memAccess = memAccessAttr.template cast(); + auto memAccess = llvm::cast(memAccessAttr); if (!memAccess) { return memoryOp.emitOpError("invalid memory access specifier: ") @@ -589,7 +589,7 @@ // TODO: Check that the value type satisfies restrictions of // SPIR-V OpLoad/OpStore operations if (val.getType() != - ptr.getType().cast().getPointeeType()) { + llvm::cast(ptr.getType()).getPointeeType()) { return op.emitOpError("mismatch in result type and pointer type"); } return success(); @@ -599,10 +599,11 @@ static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val) { auto valType = val.getType(); - if (auto valVecTy = valType.dyn_cast()) + if (auto valVecTy = llvm::dyn_cast(valType)) valType = valVecTy.getElementType(); - if (valType != ptr.getType().cast().getPointeeType()) { + if (valType != + llvm::cast(ptr.getType()).getPointeeType()) { return op.emitOpError("mismatch in result type and pointer type"); } return success(); @@ -674,7 +675,7 @@ // Get bit width of types. static unsigned getBitWidth(Type type) { - if (type.isa()) { + if (llvm::isa(type)) { // Just return 64 bits for pointer types for now. // TODO: Make sure not caller relies on the actual pointer width value. return 64; @@ -683,7 +684,7 @@ if (type.isIntOrFloat()) return type.getIntOrFloatBitWidth(); - if (auto vectorType = type.dyn_cast()) { + if (auto vectorType = llvm::dyn_cast(type)) { assert(vectorType.getElementType().isIntOrFloat()); return vectorType.getNumElements() * vectorType.getElementType().getIntOrFloatBitWidth(); @@ -703,7 +704,7 @@ } for (auto index : indices) { - if (auto cType = type.dyn_cast()) { + if (auto cType = llvm::dyn_cast(type)) { if (cType.hasCompileTimeKnownNumElements() && (index < 0 || static_cast(index) >= cType.getNumElements())) { @@ -723,7 +724,7 @@ static Type getElementType(Type type, Attribute indices, function_ref emitErrorFn) { - auto indicesArrayAttr = indices.dyn_cast(); + auto indicesArrayAttr = llvm::dyn_cast(indices); if (!indicesArrayAttr) { emitErrorFn("expected a 32-bit integer array attribute for 'indices'"); return nullptr; @@ -735,7 +736,7 @@ SmallVector indexVals; for (auto indexAttr : indicesArrayAttr) { - auto indexIntAttr = indexAttr.dyn_cast(); + auto indexIntAttr = llvm::dyn_cast(indexAttr); if (!indexIntAttr) { emitErrorFn("expected an 32-bit integer for index, but found '") << indexAttr << "'"; @@ -769,7 +770,7 @@ template static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) { - auto resultType = op.getType().template cast(); + auto resultType = llvm::cast(op.getType()); if (resultType.getNumElements() != 2) return op.emitOpError("expected result struct type containing two members"); @@ -794,7 +795,7 @@ if (parser.parseType(resultType)) return failure(); - auto structType = resultType.dyn_cast(); + auto structType = llvm::dyn_cast(resultType); if (!structType || structType.getNumElements() != 2) return parser.emitError(loc, "expected spirv.struct type with two members"); @@ -836,7 +837,7 @@ parser.getCurrentLocation(&loc) || parser.parseColonType(type)) return failure(); - auto ptrType = type.dyn_cast(); + auto ptrType = llvm::dyn_cast(type); if (!ptrType) return parser.emitError(loc, "expected pointer type"); @@ -877,9 +878,9 @@ // Verifies an atomic update op. template static LogicalResult verifyAtomicUpdateOp(Operation *op) { - auto ptrType = op->getOperand(0).getType().cast(); + auto ptrType = llvm::cast(op->getOperand(0).getType()); auto elementType = ptrType.getPointeeType(); - if (!elementType.isa()) + if (!llvm::isa(elementType)) return op->emitOpError() << "pointer operand must point to an " << stringifyTypeName() << " value, found " << elementType; @@ -990,7 +991,7 @@ static Type getUnaryOpResultType(Type operandType) { Builder builder(operandType.getContext()); Type resultType = builder.getIntegerType(1); - if (auto vecType = operandType.dyn_cast()) + if (auto vecType = llvm::dyn_cast(operandType)) return VectorType::get(vecType.getNumElements(), resultType); return resultType; } @@ -1010,7 +1011,7 @@ //===----------------------------------------------------------------------===// static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) { - auto ptrType = type.dyn_cast(); + auto ptrType = llvm::dyn_cast(type); if (!ptrType) { emitError(baseLoc, "'spirv.AccessChain' op expected a pointer " "to composite type, but provided ") @@ -1023,7 +1024,7 @@ int32_t index = 0; for (auto indexSSA : indices) { - auto cType = resultType.dyn_cast(); + auto cType = llvm::dyn_cast(resultType); if (!cType) { emitError( baseLoc, @@ -1032,7 +1033,7 @@ return nullptr; } index = 0; - if (resultType.isa()) { + if (llvm::isa(resultType)) { Operation *op = indexSSA.getDefiningOp(); if (!op) { emitError(baseLoc, "'spirv.AccessChain' op index must be an " @@ -1134,7 +1135,7 @@ return failure(); auto providedResultType = - accessChainOp.getType().template dyn_cast(); + llvm::dyn_cast(accessChainOp.getType()); if (!providedResultType) return accessChainOp.emitOpError( "result type must be a pointer, but provided") @@ -1201,7 +1202,7 @@ if (parser.parseColonType(type)) return failure(); - auto ptrType = type.dyn_cast(); + auto ptrType = llvm::dyn_cast(type); if (!ptrType) return parser.emitError(loc, "expected pointer type"); @@ -1231,10 +1232,9 @@ "result, but found ") << atomOp.getComparator().getType() << " vs " << atomOp.getType(); - Type pointeeType = atomOp.getPointer() - .getType() - .template cast() - .getPointeeType(); + Type pointeeType = + llvm::cast(atomOp.getPointer().getType()) + .getPointeeType(); if (atomOp.getType() != pointeeType) return atomOp.emitOpError( "pointer operand's pointee type must have the same " @@ -1322,7 +1322,7 @@ if (parser.parseColonType(type)) return failure(); - auto ptrType = type.dyn_cast(); + auto ptrType = llvm::dyn_cast(type); if (!ptrType) return parser.emitError(loc, "expected pointer type"); @@ -1340,7 +1340,7 @@ << getValue().getType() << " vs " << getType(); Type pointeeType = - getPointer().getType().cast().getPointeeType(); + llvm::cast(getPointer().getType()).getPointeeType(); if (getType() != pointeeType) return emitOpError("pointer operand's pointee type must have the same " "as the op result type, but found ") @@ -1537,13 +1537,13 @@ if (operandType == resultType) { return emitError("result type must be different from operand type"); } - if (operandType.isa() && - !resultType.isa()) { + if (llvm::isa(operandType) && + !llvm::isa(resultType)) { return emitError( "unhandled bit cast conversion from pointer type to non-pointer type"); } - if (!operandType.isa() && - resultType.isa()) { + if (!llvm::isa(operandType) && + llvm::isa(resultType)) { return emitError( "unhandled bit cast conversion from non-pointer type to pointer type"); } @@ -1562,8 +1562,8 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::PtrCastToGenericOp::verify() { - auto operandType = getPointer().getType().cast(); - auto resultType = getResult().getType().cast(); + auto operandType = llvm::cast(getPointer().getType()); + auto resultType = llvm::cast(getResult().getType()); spirv::StorageClass operandStorage = operandType.getStorageClass(); if (operandStorage != spirv::StorageClass::Workgroup && @@ -1590,8 +1590,8 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::GenericCastToPtrOp::verify() { - auto operandType = getPointer().getType().cast(); - auto resultType = getResult().getType().cast(); + auto operandType = llvm::cast(getPointer().getType()); + auto resultType = llvm::cast(getResult().getType()); spirv::StorageClass operandStorage = operandType.getStorageClass(); if (operandStorage != spirv::StorageClass::Generic) @@ -1618,8 +1618,8 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::GenericCastToPtrExplicitOp::verify() { - auto operandType = getPointer().getType().cast(); - auto resultType = getResult().getType().cast(); + auto operandType = llvm::cast(getPointer().getType()); + auto resultType = llvm::cast(getResult().getType()); spirv::StorageClass operandStorage = operandType.getStorageClass(); if (operandStorage != spirv::StorageClass::Generic) @@ -1719,7 +1719,7 @@ if (auto weights = getBranchWeights()) { printer << " ["; llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) { - printer << a.cast().getInt(); + printer << llvm::cast(a).getInt(); }); printer << "]"; } @@ -1736,7 +1736,7 @@ return emitOpError("must have exactly two branch weights"); } if (llvm::all_of(*weights, [](Attribute attr) { - return attr.cast().getValue().isZero(); + return llvm::cast(attr).getValue().isZero(); })) return emitOpError("branch weights cannot both be zero"); } @@ -1749,10 +1749,10 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::CompositeConstructOp::verify() { - auto cType = getType().cast(); + auto cType = llvm::cast(getType()); operand_range constituents = this->getConstituents(); - if (auto coopType = cType.dyn_cast()) { + if (auto coopType = llvm::dyn_cast(cType)) { if (constituents.size() != 1) return emitOpError("has incorrect number of operands: expected ") << "1, but provided " << constituents.size(); @@ -1763,7 +1763,7 @@ return success(); } - if (auto jointType = cType.dyn_cast()) { + if (auto jointType = llvm::dyn_cast(cType)) { if (constituents.size() != 1) return emitOpError("has incorrect number of operands: expected ") << "1, but provided " << constituents.size(); @@ -1787,7 +1787,7 @@ // If not constructing a cooperative matrix type, then we must be constructing // a vector type. - auto resultType = cType.dyn_cast(); + auto resultType = llvm::dyn_cast(cType); if (!resultType) return emitOpError( "expected to return a vector or cooperative matrix when the number of " @@ -1795,14 +1795,14 @@ SmallVector sizes; for (Value component : constituents) { - if (!component.getType().isa() && + if (!llvm::isa(component.getType()) && !component.getType().isIntOrFloat()) return emitOpError("operand type mismatch: expected operand to have " "a scalar or vector type, but provided ") << component.getType(); Type elementType = component.getType(); - if (auto vectorType = component.getType().dyn_cast()) { + if (auto vectorType = llvm::dyn_cast(component.getType())) { sizes.push_back(vectorType.getNumElements()); elementType = vectorType.getElementType(); } else { @@ -1866,7 +1866,7 @@ } LogicalResult spirv::CompositeExtractOp::verify() { - auto indicesArrayAttr = getIndices().dyn_cast(); + auto indicesArrayAttr = llvm::dyn_cast(getIndices()); auto resultType = getElementType(getComposite().getType(), indicesArrayAttr, getLoc()); if (!resultType) @@ -1909,7 +1909,7 @@ } LogicalResult spirv::CompositeInsertOp::verify() { - auto indicesArrayAttr = getIndices().dyn_cast(); + auto indicesArrayAttr = llvm::dyn_cast(getIndices()); auto objectType = getElementType(getComposite().getType(), indicesArrayAttr, getLoc()); if (!objectType) @@ -1946,9 +1946,9 @@ return failure(); Type type = NoneType::get(parser.getContext()); - if (auto typedAttr = value.dyn_cast()) + if (auto typedAttr = llvm::dyn_cast(value)) type = typedAttr.getType(); - if (type.isa()) { + if (llvm::isa(type)) { if (parser.parseColonType(type)) return failure(); } @@ -1958,25 +1958,25 @@ void spirv::ConstantOp::print(OpAsmPrinter &printer) { printer << ' ' << getValue(); - if (getType().isa()) + if (llvm::isa(getType())) printer << " : " << getType(); } static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType) { - if (value.isa()) { - auto valueType = value.cast().getType(); + if (llvm::isa(value)) { + auto valueType = llvm::cast(value).getType(); if (valueType != opType) return op.emitOpError("result type (") << opType << ") does not match value type (" << valueType << ")"; return success(); } - if (value.isa()) { - auto valueType = value.cast().getType(); + if (llvm::isa(value)) { + auto valueType = llvm::cast(value).getType(); if (valueType == opType) return success(); - auto arrayType = opType.dyn_cast(); - auto shapedType = valueType.dyn_cast(); + auto arrayType = llvm::dyn_cast(opType); + auto shapedType = llvm::dyn_cast(valueType); if (!arrayType) return op.emitOpError("result or element type (") << opType << ") does not match value type (" << valueType @@ -1984,7 +1984,7 @@ int numElements = arrayType.getNumElements(); auto opElemType = arrayType.getElementType(); - while (auto t = opElemType.dyn_cast()) { + while (auto t = llvm::dyn_cast(opElemType)) { numElements *= t.getNumElements(); opElemType = t.getElementType(); } @@ -2005,8 +2005,8 @@ } return success(); } - if (auto arrayAttr = value.dyn_cast()) { - auto arrayType = opType.dyn_cast(); + if (auto arrayAttr = llvm::dyn_cast(value)) { + auto arrayType = llvm::dyn_cast(opType); if (!arrayType) return op.emitOpError( "must have spirv.array result type for array value"); @@ -2030,12 +2030,12 @@ bool spirv::ConstantOp::isBuildableWith(Type type) { // Must be valid SPIR-V type first. - if (!type.isa()) + if (!llvm::isa(type)) return false; if (isa(type.getDialect())) { // TODO: support constant struct - return type.isa(); + return llvm::isa(type); } return true; @@ -2043,7 +2043,7 @@ spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc, OpBuilder &builder) { - if (auto intType = type.dyn_cast()) { + if (auto intType = llvm::dyn_cast(type)) { unsigned width = intType.getWidth(); if (width == 1) return builder.create(loc, type, @@ -2051,19 +2051,19 @@ return builder.create( loc, type, builder.getIntegerAttr(type, APInt(width, 0))); } - if (auto floatType = type.dyn_cast()) { + if (auto floatType = llvm::dyn_cast(type)) { return builder.create( loc, type, builder.getFloatAttr(floatType, 0.0)); } - if (auto vectorType = type.dyn_cast()) { + if (auto vectorType = llvm::dyn_cast(type)) { Type elemType = vectorType.getElementType(); - if (elemType.isa()) { + if (llvm::isa(elemType)) { return builder.create( loc, type, DenseElementsAttr::get(vectorType, IntegerAttr::get(elemType, 0).getValue())); } - if (elemType.isa()) { + if (llvm::isa(elemType)) { return builder.create( loc, type, DenseFPElementsAttr::get(vectorType, @@ -2076,7 +2076,7 @@ spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc, OpBuilder &builder) { - if (auto intType = type.dyn_cast()) { + if (auto intType = llvm::dyn_cast(type)) { unsigned width = intType.getWidth(); if (width == 1) return builder.create(loc, type, @@ -2084,19 +2084,19 @@ return builder.create( loc, type, builder.getIntegerAttr(type, APInt(width, 1))); } - if (auto floatType = type.dyn_cast()) { + if (auto floatType = llvm::dyn_cast(type)) { return builder.create( loc, type, builder.getFloatAttr(floatType, 1.0)); } - if (auto vectorType = type.dyn_cast()) { + if (auto vectorType = llvm::dyn_cast(type)) { Type elemType = vectorType.getElementType(); - if (elemType.isa()) { + if (llvm::isa(elemType)) { return builder.create( loc, type, DenseElementsAttr::get(vectorType, IntegerAttr::get(elemType, 1).getValue())); } - if (elemType.isa()) { + if (llvm::isa(elemType)) { return builder.create( loc, type, DenseFPElementsAttr::get(vectorType, @@ -2115,9 +2115,9 @@ llvm::raw_svector_ostream specialName(specialNameBuffer); specialName << "cst"; - IntegerType intTy = type.dyn_cast(); + IntegerType intTy = llvm::dyn_cast(type); - if (IntegerAttr intCst = getValue().dyn_cast()) { + if (IntegerAttr intCst = llvm::dyn_cast(getValue())) { if (intTy && intTy.getWidth() == 1) { return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); } @@ -2131,17 +2131,18 @@ } } - if (intTy || type.isa()) { + if (intTy || llvm::isa(type)) { specialName << '_' << type; } - if (auto vecType = type.dyn_cast()) { + if (auto vecType = llvm::dyn_cast(type)) { specialName << "_vec_"; specialName << vecType.getDimSize(0); Type elementType = vecType.getElementType(); - if (elementType.isa() || elementType.isa()) { + if (llvm::isa(elementType) || + llvm::isa(elementType)) { specialName << "x" << elementType; } } @@ -2210,9 +2211,10 @@ auto resultType = getResult().getType(); // ODS checks that vector result type and vector operand type have the same // shape. - if (auto vectorType = operandType.dyn_cast()) { + if (auto vectorType = llvm::dyn_cast(operandType)) { unsigned operandNumElements = vectorType.getNumElements(); - unsigned resultNumElements = resultType.cast().getNumElements(); + unsigned resultNumElements = + llvm::cast(resultType).getNumElements(); if (operandNumElements != resultNumElements) { return emitOpError( "operand and result must have same number of elements"); @@ -2230,9 +2232,10 @@ auto resultType = getResult().getType(); // ODS checks that vector result type and vector operand type have the same // shape. - if (auto vectorType = operandType.dyn_cast()) { + if (auto vectorType = llvm::dyn_cast(operandType)) { unsigned operandNumElements = vectorType.getNumElements(); - unsigned resultNumElements = resultType.cast().getNumElements(); + unsigned resultNumElements = + llvm::cast(resultType).getNumElements(); if (operandNumElements != resultNumElements) { return emitOpError( "operand and result must have same number of elements"); @@ -2331,7 +2334,7 @@ if (parser.parseAttribute(value, i32Type, "value", attr)) { return failure(); } - values.push_back(value.cast().getInt()); + values.push_back(llvm::cast(value).getInt()); } result.addAttribute(kValuesAttrName, parser.getBuilder().getI32ArrayAttr(values)); @@ -2347,7 +2350,7 @@ return; printer << ", "; llvm::interleaveComma(values, printer, [&](Attribute a) { - printer << a.cast().getInt(); + printer << llvm::cast(a).getInt(); }); } @@ -2677,7 +2680,7 @@ if (parser.parseColonType(type)) { return failure(); } - if (!type.isa()) { + if (!llvm::isa(type)) { return parser.emitError(loc, "expected spirv.ptr type"); } result.addAttribute(kTypeAttrName, TypeAttr::get(type)); @@ -2708,7 +2711,7 @@ } LogicalResult spirv::GlobalVariableOp::verify() { - if (!getType().isa()) + if (!llvm::isa(getType())) return emitOpError("result must be of a !spv.ptr type"); // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the @@ -2748,7 +2751,7 @@ if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); - if (auto localIdTy = getLocalid().getType().dyn_cast()) + if (auto localIdTy = llvm::dyn_cast(getLocalid().getType())) if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3) return emitOpError("localid is a vector and can be with only " " 2 or 3 components, actual number is ") @@ -2839,7 +2842,7 @@ } auto ptrType = spirv::PointerType::get(elementType, storageClass); - if (auto valVecTy = elementType.dyn_cast()) + if (auto valVecTy = llvm::dyn_cast(elementType)) ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass); if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) { @@ -2879,7 +2882,7 @@ } auto ptrType = spirv::PointerType::get(elementType, storageClass); - if (auto valVecTy = elementType.dyn_cast()) + if (auto valVecTy = llvm::dyn_cast(elementType)) ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass); if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc, @@ -3148,7 +3151,7 @@ void spirv::LoadOp::build(OpBuilder &builder, OperationState &state, Value basePtr, MemoryAccessAttr memoryAccess, IntegerAttr alignment) { - auto ptrType = basePtr.getType().cast(); + auto ptrType = llvm::cast(basePtr.getType()); build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess, alignment); } @@ -3177,7 +3180,7 @@ void spirv::LoadOp::print(OpAsmPrinter &printer) { SmallVector elidedAttrs; StringRef sc = stringifyStorageClass( - getPtr().getType().cast().getStorageClass()); + llvm::cast(getPtr().getType()).getStorageClass()); printer << " \"" << sc << "\" " << getPtr(); printMemoryAccessAttribute(*this, printer, elidedAttrs); @@ -3494,7 +3497,7 @@ } if (auto interface = entryPointOp.getInterface()) { for (Attribute varRef : interface) { - auto varSymRef = varRef.dyn_cast(); + auto varSymRef = llvm::dyn_cast(varRef); if (!varSymRef) { return entryPointOp.emitError( "expected symbol reference for interface " @@ -3587,8 +3590,8 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::SelectOp::verify() { - if (auto conditionTy = getCondition().getType().dyn_cast()) { - auto resultVectorTy = getResult().getType().dyn_cast(); + if (auto conditionTy = llvm::dyn_cast(getCondition().getType())) { + auto resultVectorTy = llvm::dyn_cast(getResult().getType()); if (!resultVectorTy) { return emitOpError("result expected to be of vector type when " "condition is of vector type"); @@ -3760,9 +3763,9 @@ return emitOpError("SpecId cannot be negative"); auto value = getDefaultValue(); - if (value.isa()) { + if (llvm::isa(value)) { // Make sure bitwidth is allowed. - if (!value.getType().isa()) + if (!llvm::isa(value.getType())) return emitOpError("default value bitwidth disallowed"); return success(); } @@ -3798,7 +3801,7 @@ void spirv::StoreOp::print(OpAsmPrinter &printer) { SmallVector elidedAttrs; StringRef sc = stringifyStorageClass( - getPtr().getType().cast().getStorageClass()); + llvm::cast(getPtr().getType()).getStorageClass()); printer << " \"" << sc << "\" " << getPtr() << ", " << getValue(); printMemoryAccessAttribute(*this, printer, elidedAttrs); @@ -3861,7 +3864,7 @@ if (parser.parseType(type)) return failure(); - auto ptrType = type.dyn_cast(); + auto ptrType = llvm::dyn_cast(type); if (!ptrType) return parser.emitError(loc, "expected spirv.ptr type"); result.addTypes(ptrType); @@ -3901,7 +3904,7 @@ "spirv.GlobalVariable for module-level variables."); } - auto pointerType = getPointer().getType().cast(); + auto pointerType = llvm::cast(getPointer().getType()); if (getStorageClass() != pointerType.getStorageClass()) return emitOpError( "storage class must match result pointer's storage class"); @@ -3940,7 +3943,7 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::VectorShuffleOp::verify() { - VectorType resultType = getType().cast(); + VectorType resultType = llvm::cast(getType()); size_t numResultElements = resultType.getNumElements(); if (numResultElements != getComponents().size()) @@ -3950,8 +3953,8 @@ << getComponents().size() << ")"; size_t totalSrcElements = - getVector1().getType().cast().getNumElements() + - getVector2().getType().cast().getNumElements(); + llvm::cast(getVector1().getType()).getNumElements() + + llvm::cast(getVector2().getType()).getNumElements(); for (const auto &selector : getComponents().getAsValueRange()) { uint32_t index = selector.getZExtValue(); @@ -4001,13 +4004,14 @@ static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer, Type coopMatrix) { - Type pointeeType = pointer.cast().getPointeeType(); - if (!pointeeType.isa() && !pointeeType.isa()) + Type pointeeType = llvm::cast(pointer).getPointeeType(); + if (!llvm::isa(pointeeType) && + !llvm::isa(pointeeType)) return op->emitError( "Pointer must point to a scalar or vector type but provided ") << pointeeType; spirv::StorageClass storage = - pointer.cast().getStorageClass(); + llvm::cast(pointer).getStorageClass(); if (storage != spirv::StorageClass::Workgroup && storage != spirv::StorageClass::StorageBuffer && storage != spirv::StorageClass::PhysicalStorageBuffer) @@ -4071,10 +4075,11 @@ verifyCoopMatrixMulAdd(spirv::NVCooperativeMatrixMulAddOp op) { if (op.getC().getType() != op.getResult().getType()) return op.emitOpError("result and third operand must have the same type"); - auto typeA = op.getA().getType().cast(); - auto typeB = op.getB().getType().cast(); - auto typeC = op.getC().getType().cast(); - auto typeR = op.getResult().getType().cast(); + auto typeA = llvm::cast(op.getA().getType()); + auto typeB = llvm::cast(op.getB().getType()); + auto typeC = llvm::cast(op.getC().getType()); + auto typeR = + llvm::cast(op.getResult().getType()); if (typeA.getRows() != typeR.getRows() || typeA.getColumns() != typeB.getRows() || typeB.getColumns() != typeR.getColumns()) @@ -4086,8 +4091,8 @@ auto elementTypeA = typeA.getElementType(); auto elementTypeB = typeB.getElementType(); if (isa(elementTypeA) && isa(elementTypeB)) { - if (elementTypeA.cast().getWidth() != - elementTypeB.cast().getWidth()) + if (llvm::cast(elementTypeA).getWidth() != + llvm::cast(elementTypeB).getWidth()) return op.emitOpError( "matrix A and B integer element types must be the same bit width"); } else if (elementTypeA != elementTypeB) { @@ -4105,13 +4110,14 @@ static LogicalResult verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) { - Type pointeeType = pointer.cast().getPointeeType(); - if (!pointeeType.isa() && !pointeeType.isa()) + Type pointeeType = llvm::cast(pointer).getPointeeType(); + if (!llvm::isa(pointeeType) && + !llvm::isa(pointeeType)) return op->emitError( "Pointer must point to a scalar or vector type but provided ") << pointeeType; spirv::StorageClass storage = - pointer.cast().getStorageClass(); + llvm::cast(pointer).getStorageClass(); if (storage != spirv::StorageClass::Workgroup && storage != spirv::StorageClass::CrossWorkgroup && storage != spirv::StorageClass::UniformConstant && @@ -4147,10 +4153,11 @@ static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op) { if (op.getC().getType() != op.getResult().getType()) return op.emitOpError("result and third operand must have the same type"); - auto typeA = op.getA().getType().cast(); - auto typeB = op.getB().getType().cast(); - auto typeC = op.getC().getType().cast(); - auto typeR = op.getResult().getType().cast(); + auto typeA = llvm::cast(op.getA().getType()); + auto typeB = llvm::cast(op.getB().getType()); + auto typeC = llvm::cast(op.getC().getType()); + auto typeR = + llvm::cast(op.getResult().getType()); if (typeA.getRows() != typeR.getRows() || typeA.getColumns() != typeB.getRows() || typeB.getColumns() != typeR.getColumns()) @@ -4174,8 +4181,8 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::MatrixTimesScalarOp::verify() { - if (auto inputCoopmat = - getMatrix().getType().dyn_cast()) { + if (auto inputCoopmat = llvm::dyn_cast( + getMatrix().getType())) { if (inputCoopmat.getElementType() != getScalar().getType()) return emitError("input matrix components' type and scaling value must " "have the same type"); @@ -4183,7 +4190,7 @@ } // Check that the scalar type is the same as the matrix element type. - auto inputMatrix = getMatrix().getType().cast(); + auto inputMatrix = llvm::cast(getMatrix().getType()); if (getScalar().getType() != inputMatrix.getElementType()) return emitError("input matrix components' type and scaling value must " "have the same type"); @@ -4199,11 +4206,11 @@ printer << ' '; StringRef targetStorageClass = stringifyStorageClass( - getTarget().getType().cast().getStorageClass()); + llvm::cast(getTarget().getType()).getStorageClass()); printer << " \"" << targetStorageClass << "\" " << getTarget() << ", "; StringRef sourceStorageClass = stringifyStorageClass( - getSource().getType().cast().getStorageClass()); + llvm::cast(getSource().getType()).getStorageClass()); printer << " \"" << sourceStorageClass << "\" " << getSource(); SmallVector elidedAttrs; @@ -4215,7 +4222,7 @@ printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); Type pointeeType = - getTarget().getType().cast().getPointeeType(); + llvm::cast(getTarget().getType()).getPointeeType(); printer << " : " << pointeeType; } @@ -4263,10 +4270,10 @@ LogicalResult spirv::CopyMemoryOp::verify() { Type targetType = - getTarget().getType().cast().getPointeeType(); + llvm::cast(getTarget().getType()).getPointeeType(); Type sourceType = - getSource().getType().cast().getPointeeType(); + llvm::cast(getSource().getType()).getPointeeType(); if (targetType != sourceType) return emitOpError("both operands must be pointers to the same type"); @@ -4290,8 +4297,8 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::TransposeOp::verify() { - auto inputMatrix = getMatrix().getType().cast(); - auto resultMatrix = getResult().getType().cast(); + auto inputMatrix = llvm::cast(getMatrix().getType()); + auto resultMatrix = llvm::cast(getResult().getType()); // Verify that the input and output matrices have correct shapes. if (inputMatrix.getNumRows() != resultMatrix.getNumColumns()) @@ -4315,9 +4322,9 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::MatrixTimesMatrixOp::verify() { - auto leftMatrix = getLeftmatrix().getType().cast(); - auto rightMatrix = getRightmatrix().getType().cast(); - auto resultMatrix = getResult().getType().cast(); + auto leftMatrix = llvm::cast(getLeftmatrix().getType()); + auto rightMatrix = llvm::cast(getRightmatrix().getType()); + auto resultMatrix = llvm::cast(getResult().getType()); // left matrix columns' count and right matrix rows' count must be equal if (leftMatrix.getNumColumns() != rightMatrix.getNumRows()) @@ -4403,16 +4410,16 @@ } LogicalResult spirv::SpecConstantCompositeOp::verify() { - auto cType = getType().dyn_cast(); + auto cType = llvm::dyn_cast(getType()); auto constituents = this->getConstituents().getValue(); if (!cType) return emitError("result type must be a composite type, but provided ") << getType(); - if (cType.isa()) + if (llvm::isa(cType)) return emitError("unsupported composite type ") << cType; - if (cType.isa()) + if (llvm::isa(cType)) return emitError("unsupported composite type ") << cType; if (constituents.size() != cType.getNumElements()) return emitError("has incorrect number of operands: expected ") @@ -4420,7 +4427,7 @@ << constituents.size(); for (auto index : llvm::seq(0, constituents.size())) { - auto constituent = constituents[index].cast(); + auto constituent = llvm::cast(constituents[index]); auto constituentSpecConstOp = dyn_cast(SymbolTable::lookupNearestSymbolFrom( @@ -4498,19 +4505,19 @@ LogicalResult spirv::GLFrexpStructOp::verify() { spirv::StructType structTy = - getResult().getType().dyn_cast(); + llvm::dyn_cast(getResult().getType()); if (structTy.getNumElements() != 2) return emitError("result type must be a struct type with two memebers"); Type significandTy = structTy.getElementType(0); Type exponentTy = structTy.getElementType(1); - VectorType exponentVecTy = exponentTy.dyn_cast(); - IntegerType exponentIntTy = exponentTy.dyn_cast(); + VectorType exponentVecTy = llvm::dyn_cast(exponentTy); + IntegerType exponentIntTy = llvm::dyn_cast(exponentTy); Type operandTy = getOperand().getType(); - VectorType operandVecTy = operandTy.dyn_cast(); - FloatType operandFTy = operandTy.dyn_cast(); + VectorType operandVecTy = llvm::dyn_cast(operandTy); + FloatType operandFTy = llvm::dyn_cast(operandTy); if (significandTy != operandTy) return emitError("member zero of the resulting struct type must be the " @@ -4518,7 +4525,7 @@ if (exponentVecTy) { IntegerType componentIntTy = - exponentVecTy.getElementType().dyn_cast(); + llvm::dyn_cast(exponentVecTy.getElementType()); if (!componentIntTy || componentIntTy.getWidth() != 32) return emitError("member one of the resulting struct type must" "be a scalar or vector of 32 bit integer type"); @@ -4547,11 +4554,12 @@ Type significandType = getX().getType(); Type exponentType = getExp().getType(); - if (significandType.isa() != exponentType.isa()) + if (llvm::isa(significandType) != + llvm::isa(exponentType)) return emitOpError("operands must both be scalars or vectors"); auto getNumElements = [](Type type) -> unsigned { - if (auto vectorType = type.dyn_cast()) + if (auto vectorType = llvm::dyn_cast(type)) return vectorType.getNumElements(); return 1; }; @@ -4567,17 +4575,19 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::ImageDrefGatherOp::verify() { - VectorType resultType = getResult().getType().cast(); + VectorType resultType = llvm::cast(getResult().getType()); auto sampledImageType = - getSampledimage().getType().cast(); - auto imageType = sampledImageType.getImageType().cast(); + llvm::cast(getSampledimage().getType()); + auto imageType = + llvm::cast(sampledImageType.getImageType()); if (resultType.getNumElements() != 4) return emitOpError("result type must be a vector of four components"); Type elementType = resultType.getElementType(); Type sampledElementType = imageType.getElementType(); - if (!sampledElementType.isa() && elementType != sampledElementType) + if (!llvm::isa(sampledElementType) && + elementType != sampledElementType) return emitOpError( "the component type of result must be the same as sampled type of the " "underlying image type"); @@ -4629,7 +4639,8 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::ImageQuerySizeOp::verify() { - spirv::ImageType imageType = getImage().getType().cast(); + spirv::ImageType imageType = + llvm::cast(getImage().getType()); Type resultType = getResult().getType(); spirv::Dim dim = imageType.getDim(); @@ -4677,7 +4688,7 @@ componentNumber += 1; unsigned resultComponentNumber = 1; - if (auto resultVectorType = resultType.dyn_cast()) + if (auto resultVectorType = llvm::dyn_cast(resultType)) resultComponentNumber = resultVectorType.getNumElements(); if (componentNumber != resultComponentNumber) @@ -4798,7 +4809,7 @@ LogicalResult spirv::VectorTimesScalarOp::verify() { if (getVector().getType() != getType()) return emitOpError("vector operand and result type mismatch"); - auto scalarType = getType().cast().getElementType(); + auto scalarType = llvm::cast(getType()).getElementType(); if (getScalar().getType() != scalarType) return emitOpError("scalar operand and result element type match"); return success(); @@ -4851,11 +4862,11 @@ return op->emitOpError("requires the same type for both vector operands"); unsigned expectedNumAttrs = 0; - if (auto intTy = factorTy.dyn_cast()) { + if (auto intTy = llvm::dyn_cast(factorTy)) { ++expectedNumAttrs; auto packedVectorFormat = - op->getAttr(kPackedVectorFormatAttrName) - .dyn_cast_or_null(); + llvm::dyn_cast_or_null( + op->getAttr(kPackedVectorFormatAttrName)); if (!packedVectorFormat) return op->emitOpError("requires Packed Vector Format attribute for " "integer vector operands"); @@ -4927,9 +4938,9 @@ SmallVector, 1> capabilities = {dotProductCap}; Type factorTy = op->getOperand(0).getType(); - if (auto intTy = factorTy.dyn_cast()) { - auto formatAttr = op->getAttr(kPackedVectorFormatAttrName) - .cast(); + if (auto intTy = llvm::dyn_cast(factorTy)) { + auto formatAttr = llvm::cast( + op->getAttr(kPackedVectorFormatAttrName)); if (formatAttr.getValue() == spirv::PackedVectorFormat::PackedVectorFormat4x8Bit) capabilities.push_back(dotProductInput4x8BitPackedCap); @@ -4937,7 +4948,7 @@ return capabilities; } - auto vecTy = factorTy.cast(); + auto vecTy = llvm::cast(factorTy); if (vecTy.getElementTypeBitWidth() == 8) { capabilities.push_back(dotProductInput4x8BitCap); return capabilities; diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -68,17 +68,18 @@ void ArrayType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional storage) { - getElementType().cast().getExtensions(extensions, storage); + llvm::cast(getElementType()).getExtensions(extensions, storage); } void ArrayType::getCapabilities( SPIRVType::CapabilityArrayRefVector &capabilities, std::optional storage) { - getElementType().cast().getCapabilities(capabilities, storage); + llvm::cast(getElementType()) + .getCapabilities(capabilities, storage); } std::optional ArrayType::getSizeInBytes() { - auto elementType = getElementType().cast(); + auto elementType = llvm::cast(getElementType()); std::optional size = elementType.getSizeInBytes(); if (!size) return std::nullopt; @@ -90,11 +91,11 @@ //===----------------------------------------------------------------------===// bool CompositeType::classof(Type type) { - if (auto vectorType = type.dyn_cast()) + if (auto vectorType = llvm::dyn_cast(type)) return isValid(vectorType); - return type.isa(); + return llvm::isa(type); } bool CompositeType::isValid(VectorType type) { @@ -108,7 +109,7 @@ default: return false; } - return type.getRank() == 1 && type.getElementType().isa(); + return type.getRank() == 1 && llvm::isa(type.getElementType()); } Type CompositeType::getElementType(unsigned index) const { @@ -160,8 +161,8 @@ MatrixType, RuntimeArrayType, StructType>( [&](auto type) { type.getExtensions(extensions, storage); }) .Case([&](VectorType type) { - return type.getElementType().cast().getExtensions( - extensions, storage); + return llvm::cast(type.getElementType()) + .getExtensions(extensions, storage); }) .Default([](Type) { llvm_unreachable("invalid composite type"); }); } @@ -180,8 +181,8 @@ ArrayRef ref(caps, std::size(caps)); capabilities.push_back(ref); } - return type.getElementType().cast().getCapabilities( - capabilities, storage); + return llvm::cast(type.getElementType()) + .getCapabilities(capabilities, storage); }) .Default([](Type) { llvm_unreachable("invalid composite type"); }); } @@ -193,7 +194,7 @@ return structType.getSizeInBytes(); if (auto vectorType = dyn_cast()) { std::optional elementSize = - vectorType.getElementType().cast().getSizeInBytes(); + llvm::cast(vectorType.getElementType()).getSizeInBytes(); if (!elementSize) return std::nullopt; return *elementSize * vectorType.getNumElements(); @@ -249,7 +250,7 @@ void CooperativeMatrixNVType::getExtensions( SPIRVType::ExtensionArrayRefVector &extensions, std::optional storage) { - getElementType().cast().getExtensions(extensions, storage); + llvm::cast(getElementType()).getExtensions(extensions, storage); static const Extension exts[] = {Extension::SPV_NV_cooperative_matrix}; ArrayRef ref(exts, std::size(exts)); extensions.push_back(ref); @@ -258,7 +259,8 @@ void CooperativeMatrixNVType::getCapabilities( SPIRVType::CapabilityArrayRefVector &capabilities, std::optional storage) { - getElementType().cast().getCapabilities(capabilities, storage); + llvm::cast(getElementType()) + .getCapabilities(capabilities, storage); static const Capability caps[] = {Capability::CooperativeMatrixNV}; ArrayRef ref(caps, std::size(caps)); capabilities.push_back(ref); @@ -317,7 +319,7 @@ void JointMatrixINTELType::getExtensions( SPIRVType::ExtensionArrayRefVector &extensions, std::optional storage) { - getElementType().cast().getExtensions(extensions, storage); + llvm::cast(getElementType()).getExtensions(extensions, storage); static const Extension exts[] = {Extension::SPV_INTEL_joint_matrix}; ArrayRef ref(exts, std::size(exts)); extensions.push_back(ref); @@ -326,7 +328,8 @@ void JointMatrixINTELType::getCapabilities( SPIRVType::CapabilityArrayRefVector &capabilities, std::optional storage) { - getElementType().cast().getCapabilities(capabilities, storage); + llvm::cast(getElementType()) + .getCapabilities(capabilities, storage); static const Capability caps[] = {Capability::JointMatrixINTEL}; ArrayRef ref(caps, std::size(caps)); capabilities.push_back(ref); @@ -489,8 +492,8 @@ std::optional storage) { // Use this pointer type's storage class because this pointer indicates we are // using the pointee type in that specific storage class. - getPointeeType().cast().getExtensions(extensions, - getStorageClass()); + llvm::cast(getPointeeType()) + .getExtensions(extensions, getStorageClass()); if (auto scExts = spirv::getExtensions(getStorageClass())) extensions.push_back(*scExts); @@ -501,8 +504,8 @@ std::optional storage) { // Use this pointer type's storage class because this pointer indicates we are // using the pointee type in that specific storage class. - getPointeeType().cast().getCapabilities(capabilities, - getStorageClass()); + llvm::cast(getPointeeType()) + .getCapabilities(capabilities, getStorageClass()); if (auto scCaps = spirv::getCapabilities(getStorageClass())) capabilities.push_back(*scCaps); @@ -547,7 +550,7 @@ void RuntimeArrayType::getExtensions( SPIRVType::ExtensionArrayRefVector &extensions, std::optional storage) { - getElementType().cast().getExtensions(extensions, storage); + llvm::cast(getElementType()).getExtensions(extensions, storage); } void RuntimeArrayType::getCapabilities( @@ -558,7 +561,8 @@ ArrayRef ref(caps, std::size(caps)); capabilities.push_back(ref); } - getElementType().cast().getCapabilities(capabilities, storage); + llvm::cast(getElementType()) + .getCapabilities(capabilities, storage); } //===----------------------------------------------------------------------===// @@ -566,10 +570,10 @@ //===----------------------------------------------------------------------===// bool ScalarType::classof(Type type) { - if (auto floatType = type.dyn_cast()) { + if (auto floatType = llvm::dyn_cast(type)) { return isValid(floatType); } - if (auto intType = type.dyn_cast()) { + if (auto intType = llvm::dyn_cast(type)) { return isValid(intType); } return false; @@ -723,9 +727,9 @@ // Allow SPIR-V dialect types if (llvm::isa(type.getDialect())) return true; - if (type.isa()) + if (llvm::isa(type)) return true; - if (auto vectorType = type.dyn_cast()) + if (auto vectorType = llvm::dyn_cast(type)) return CompositeType::isValid(vectorType); return false; } @@ -815,7 +819,7 @@ LogicalResult SampledImageType::verify(function_ref emitError, Type imageType) { - if (!imageType.isa()) + if (!llvm::isa(imageType)) return emitError() << "expected image type"; return success(); @@ -824,13 +828,13 @@ void SampledImageType::getExtensions( SPIRVType::ExtensionArrayRefVector &extensions, std::optional storage) { - getImageType().cast().getExtensions(extensions, storage); + llvm::cast(getImageType()).getExtensions(extensions, storage); } void SampledImageType::getCapabilities( SPIRVType::CapabilityArrayRefVector &capabilities, std::optional storage) { - getImageType().cast().getCapabilities(capabilities, storage); + llvm::cast(getImageType()).getCapabilities(capabilities, storage); } //===----------------------------------------------------------------------===// @@ -1125,14 +1129,14 @@ void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional storage) { for (Type elementType : getElementTypes()) - elementType.cast().getExtensions(extensions, storage); + llvm::cast(elementType).getExtensions(extensions, storage); } void StructType::getCapabilities( SPIRVType::CapabilityArrayRefVector &capabilities, std::optional storage) { for (Type elementType : getElementTypes()) - elementType.cast().getCapabilities(capabilities, storage); + llvm::cast(elementType).getCapabilities(capabilities, storage); } llvm::hash_code spirv::hash_value( @@ -1186,7 +1190,7 @@ return emitError() << "matrix columns must be vectors of floats"; /// The underlying vectors (columns) must be of size 2, 3, or 4 - ArrayRef columnShape = columnType.cast().getShape(); + ArrayRef columnShape = llvm::cast(columnType).getShape(); if (columnShape.size() != 1) return emitError() << "matrix columns must be 1D vectors"; @@ -1198,8 +1202,8 @@ /// Returns true if the matrix elements are vectors of float elements bool MatrixType::isValidColumnType(Type columnType) { - if (auto vectorType = columnType.dyn_cast()) { - if (vectorType.getElementType().isa()) + if (auto vectorType = llvm::dyn_cast(columnType)) { + if (llvm::isa(vectorType.getElementType())) return true; } return false; @@ -1208,13 +1212,13 @@ Type MatrixType::getColumnType() const { return getImpl()->columnType; } Type MatrixType::getElementType() const { - return getImpl()->columnType.cast().getElementType(); + return llvm::cast(getImpl()->columnType).getElementType(); } unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; } unsigned MatrixType::getNumRows() const { - return getImpl()->columnType.cast().getShape()[0]; + return llvm::cast(getImpl()->columnType).getShape()[0]; } unsigned MatrixType::getNumElements() const { @@ -1223,7 +1227,7 @@ void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional storage) { - getColumnType().cast().getExtensions(extensions, storage); + llvm::cast(getColumnType()).getExtensions(extensions, storage); } void MatrixType::getCapabilities( @@ -1235,7 +1239,7 @@ capabilities.push_back(ref); } // Add any capabilities associated with the underlying vectors (i.e., columns) - getColumnType().cast().getCapabilities(capabilities, storage); + llvm::cast(getColumnType()).getCapabilities(capabilities, storage); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -41,14 +41,14 @@ } bool shape::isExtentTensorType(Type type) { - auto ranked = type.dyn_cast(); + auto ranked = llvm::dyn_cast(type); return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex(); } LogicalResult shape::getShapeVec(Value input, SmallVectorImpl &shapeValues) { if (auto inputOp = input.getDefiningOp()) { - auto type = inputOp.getArg().getType().cast(); + auto type = llvm::cast(inputOp.getArg().getType()); if (!type.hasRank()) return failure(); llvm::append_range(shapeValues, type.getShape()); @@ -64,7 +64,7 @@ static bool isErrorPropagationPossible(TypeRange operandTypes) { return llvm::any_of(operandTypes, [](Type ty) { - return ty.isa(); + return llvm::isa(ty); }); } @@ -72,7 +72,7 @@ assert(op != nullptr && op->getNumResults() == 1); Type resultTy = op->getResultTypes().front(); if (isErrorPropagationPossible(op->getOperandTypes())) { - if (!resultTy.isa()) + if (!llvm::isa(resultTy)) return op->emitOpError() << "if at least one of the operands can hold error values then " "the result must be of type `size` to propagate them"; @@ -84,7 +84,7 @@ assert(op != nullptr && op->getNumResults() == 1); Type resultTy = op->getResultTypes().front(); if (isErrorPropagationPossible(op->getOperandTypes())) { - if (!resultTy.isa()) + if (!llvm::isa(resultTy)) return op->emitOpError() << "if at least one of the operands can hold error values then " "the result must be of type `shape` to propagate them"; @@ -94,7 +94,7 @@ template static bool eachHasOnlyOneOfTypes(TypeRange typeRange) { - return typeRange.size() == 1 && typeRange.front().isa(); + return typeRange.size() == 1 && llvm::isa(typeRange.front()); } template @@ -147,13 +147,15 @@ Operation *ShapeDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { - if (type.isa() || isExtentTensorType(type)) - return builder.create(loc, type, - value.cast()); - if (type.isa()) - return builder.create(loc, type, value.cast()); - if (type.isa()) - return builder.create(loc, type, value.cast()); + if (llvm::isa(type) || isExtentTensorType(type)) + return builder.create( + loc, type, llvm::cast(value)); + if (llvm::isa(type)) + return builder.create(loc, type, + llvm::cast(value)); + if (llvm::isa(type)) + return builder.create(loc, type, + llvm::cast(value)); return arith::ConstantOp::materialize(builder, value, type, loc); } @@ -165,7 +167,7 @@ return op->emitError( "shape.lib attribute may only be on op implementing SymbolTable"); - if (auto symbolRef = attribute.getValue().dyn_cast()) { + if (auto symbolRef = llvm::dyn_cast(attribute.getValue())) { auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef); if (!symbol) return op->emitError("shape function library ") @@ -176,17 +178,17 @@ << symbolRef << " required to be shape function library"; } - if (auto arr = attribute.getValue().dyn_cast()) { + if (auto arr = llvm::dyn_cast(attribute.getValue())) { // Verify all entries are function libraries and mappings in libraries // refer to unique ops. DenseSet key; for (auto it : arr) { - if (!it.isa()) + if (!llvm::isa(it)) return op->emitError( "only SymbolRefAttr allowed in shape.lib attribute array"); auto shapeFnLib = dyn_cast( - SymbolTable::lookupSymbolIn(op, it.cast())); + SymbolTable::lookupSymbolIn(op, llvm::cast(it))); if (!shapeFnLib) return op->emitError() << it << " does not refer to FunctionLibraryOp"; @@ -395,8 +397,8 @@ MLIRContext *context, std::optional location, ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { - if (operands[0].getType().isa() || - operands[1].getType().isa()) + if (llvm::isa(operands[0].getType()) || + llvm::isa(operands[1].getType())) inferredReturnTypes.assign({SizeType::get(context)}); else inferredReturnTypes.assign({IndexType::get(context)}); @@ -617,7 +619,7 @@ getOperation()->eraseOperand(idx); // Always false if any input is statically known false - if (!a.cast().getValue()) + if (!llvm::cast(a).getValue()) return a; } // If this is reached, all inputs were statically known passing. @@ -651,9 +653,11 @@ if (!adaptor.getShapes()[0] || !adaptor.getShapes()[1]) return nullptr; auto lhsShape = llvm::to_vector<6>( - adaptor.getShapes()[0].cast().getValues()); + llvm::cast(adaptor.getShapes()[0]) + .getValues()); auto rhsShape = llvm::to_vector<6>( - adaptor.getShapes()[1].cast().getValues()); + llvm::cast(adaptor.getShapes()[1]) + .getValues()); SmallVector resultShape; // If the shapes are not compatible, we can't fold it. @@ -677,7 +681,8 @@ LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { auto isPotentiallyNonEmptyShape = [](Value shape) { - if (auto extentTensorTy = shape.getType().dyn_cast()) { + if (auto extentTensorTy = + llvm::dyn_cast(shape.getType())) { if (extentTensorTy.getDimSize(0) == 0) return false; } @@ -714,11 +719,11 @@ // Insert cast if needed. if (replacement.getType() != op.getType()) { auto loc = op.getLoc(); - if (op.getType().isa()) { + if (llvm::isa(op.getType())) { replacement = rewriter.create(loc, replacement); } else { - assert(!op.getType().isa() && - !replacement.getType().isa() && + assert(!llvm::isa(op.getType()) && + !llvm::isa(replacement.getType()) && "expect extent tensor cast"); replacement = rewriter.create(loc, op.getType(), replacement); @@ -781,7 +786,7 @@ if (auto castOp = operand.getDefiningOp()) { // Only eliminate the cast if it holds no shape information. bool isInformationLoosingCast = - castOp.getType().cast().isDynamicDim(0); + llvm::cast(castOp.getType()).isDynamicDim(0); if (isInformationLoosingCast) { anyChange = true; return castOp.getSource(); @@ -807,14 +812,15 @@ LogicalResult matchAndRewrite(BroadcastOp op, PatternRewriter &rewriter) const override { // Only concretize dynamic extent tensor result types. - auto resultTy = op.getType().dyn_cast(); + auto resultTy = llvm::dyn_cast(op.getType()); if (!resultTy || !resultTy.isDynamicDim(0)) return failure(); // Infer resulting shape rank if possible. int64_t maxRank = 0; for (Value shape : op.getShapes()) { - if (auto extentTensorTy = shape.getType().dyn_cast()) { + if (auto extentTensorTy = + llvm::dyn_cast(shape.getType())) { // Cannot infer resulting shape rank if any operand is dynamically // ranked. if (extentTensorTy.isDynamicDim(0)) @@ -883,12 +889,12 @@ NamedAttrList dummy; if (parser.parseAttribute(extentsRaw, "dummy", dummy)) return failure(); - auto extentsArray = extentsRaw.dyn_cast(); + auto extentsArray = llvm::dyn_cast(extentsRaw); if (!extentsArray) return failure(); SmallVector ints; for (Attribute extent : extentsArray) { - IntegerAttr attr = extent.dyn_cast(); + IntegerAttr attr = llvm::dyn_cast(extent); if (!attr) return failure(); ints.push_back(attr.getInt()); @@ -930,7 +936,7 @@ Type lhs = l.front(); Type rhs = r.front(); - if (lhs.isa() || rhs.isa()) + if (llvm::isa(lhs) || llvm::isa(rhs)) // Shape type is compatible with all other valid return types. return true; return lhs == rhs; @@ -956,7 +962,7 @@ static bool hasAtMostSingleNonScalar(ArrayRef attributes) { bool nonScalarSeen = false; for (Attribute a : attributes) { - if (!a || a.cast().getNumElements() != 0) { + if (!a || llvm::cast(a).getNumElements() != 0) { if (nonScalarSeen) return false; nonScalarSeen = true; @@ -1070,13 +1076,13 @@ if (auto constSizeOp = getIndex().getDefiningOp()) return constSizeOp.getValue().getLimitedValue(); if (auto constantOp = getIndex().getDefiningOp()) - return constantOp.getValue().cast().getInt(); + return llvm::cast(constantOp.getValue()).getInt(); return std::nullopt; } OpFoldResult DimOp::fold(FoldAdaptor adaptor) { Type valType = getValue().getType(); - auto valShapedType = valType.dyn_cast(); + auto valShapedType = llvm::dyn_cast(valType); if (!valShapedType || !valShapedType.hasRank()) return nullptr; std::optional index = getConstantIndex(); @@ -1104,7 +1110,7 @@ } LogicalResult mlir::shape::DimOp::verify() { - auto st = getValue().getType().cast(); + auto st = llvm::cast(getValue().getType()); if (!st.hasRank()) return success(); if (auto index = getConstantIndex()) { @@ -1142,8 +1148,8 @@ MLIRContext *context, std::optional location, ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { - if (operands[0].getType().isa() || - operands[1].getType().isa()) + if (llvm::isa(operands[0].getType()) || + llvm::isa(operands[1].getType())) inferredReturnTypes.assign({SizeType::get(context)}); else inferredReturnTypes.assign({IndexType::get(context)}); @@ -1199,7 +1205,7 @@ return nullptr; SmallVector extents; for (auto attr : adaptor.getExtents()) - extents.push_back(attr.cast().getInt()); + extents.push_back(llvm::cast(attr).getInt()); Builder builder(getContext()); return builder.getIndexTensorAttr(extents); } @@ -1215,9 +1221,8 @@ } FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) { - auto attr = getMapping() - .get(op->getName().getIdentifier()) - .dyn_cast_or_null(); + auto attr = llvm::dyn_cast_or_null( + getMapping().get(op->getName().getIdentifier())); if (!attr) return nullptr; return lookupSymbol(attr); @@ -1329,7 +1334,7 @@ if (auto constSizeOp = getDim().getDefiningOp()) return constSizeOp.getValue().getLimitedValue(); if (auto constantOp = getDim().getDefiningOp()) - return constantOp.getValue().cast().getInt(); + return llvm::cast(constantOp.getValue()).getInt(); return std::nullopt; } @@ -1349,7 +1354,7 @@ int64_t dim) { auto loc = result.location; auto dimAttr = builder.getIndexAttr(dim); - if (shape.getType().isa()) { + if (llvm::isa(shape.getType())) { Value dim = builder.create(loc, dimAttr); build(builder, result, builder.getType(), shape, dim); } else { @@ -1405,7 +1410,7 @@ return failure(); auto isShapeType = [](Type arg) { - if (arg.isa()) + if (llvm::isa(arg)) return true; return isExtentTensorType(arg); }; @@ -1414,29 +1419,29 @@ Type acc = types.front(); for (auto t : drop_begin(types)) { Type l = acc, r = t; - if (!l.isa()) + if (!llvm::isa(l)) std::swap(l, r); // Handle sizes, propagate error type if present. - if (l.isa()) { - if (r.isa()) + if (llvm::isa(l)) { + if (llvm::isa(r)) acc = l; else return emitOptionalError(location, "requires all sizes or shapes"); - } else if (l.isa()) { - if (r.isa()) + } else if (llvm::isa(l)) { + if (llvm::isa(r)) acc = r; else return emitOptionalError(location, "requires all sizes or shapes"); - } else if (l.isa()) { + } else if (llvm::isa(l)) { // Handle shapes, propagate error type if present. if (isShapeType(r)) acc = l; else return emitOptionalError(location, "requires all sizes or shapes"); } else if (isExtentTensorType(l)) { - auto rank1 = l.cast().getShape()[0]; - auto rank2 = r.cast().getShape()[0]; + auto rank1 = llvm::cast(l).getShape()[0]; + auto rank2 = llvm::cast(r).getShape()[0]; if (ShapedType::isDynamic(rank1)) acc = l; else if (ShapedType::isDynamic(rank2)) @@ -1460,13 +1465,13 @@ Type lhs = l.front(); Type rhs = r.front(); - if (!lhs.isa()) + if (!llvm::isa(lhs)) std::swap(lhs, rhs); - if (lhs.isa()) - return rhs.isa(); - if (lhs.isa()) - return rhs.isa(); + if (llvm::isa(lhs)) + return llvm::isa(rhs); + if (llvm::isa(lhs)) + return llvm::isa(rhs); if (succeeded(verifyCompatibleShapes({lhs, rhs}))) return true; @@ -1511,14 +1516,14 @@ if (!shapeOfOp) return failure(); auto rankedTensorType = - shapeOfOp.getArg().getType().dyn_cast(); + llvm::dyn_cast(shapeOfOp.getArg().getType()); if (!rankedTensorType) return failure(); int64_t rank = rankedTensorType.getRank(); - if (op.getType().isa()) { + if (llvm::isa(op.getType())) { rewriter.replaceOpWithNewOp(op.getOperation(), rank); - } else if (op.getType().isa()) { + } else if (llvm::isa(op.getType())) { rewriter.replaceOpWithNewOp(op.getOperation(), rank); } else { return failure(); @@ -1537,7 +1542,7 @@ MLIRContext *context, std::optional location, ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { - if (operands[0].getType().isa()) + if (llvm::isa(operands[0].getType())) inferredReturnTypes.assign({SizeType::get(context)}); else inferredReturnTypes.assign({IndexType::get(context)}); @@ -1563,7 +1568,7 @@ return {}; APInt product(64, 1); - for (auto value : shape.cast()) + for (auto value : llvm::cast(shape)) product *= value; Builder builder(getContext()); return builder.getIndexAttr(product.getLimitedValue()); @@ -1573,7 +1578,7 @@ MLIRContext *context, std::optional location, ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { - if (operands[0].getType().isa()) + if (llvm::isa(operands[0].getType())) inferredReturnTypes.assign({SizeType::get(context)}); else inferredReturnTypes.assign({IndexType::get(context)}); @@ -1615,9 +1620,9 @@ bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { if (l.size() != 1 || r.size() != 1) return false; - if (l.front().isa() && r.front().isa()) + if (llvm::isa(l.front()) && llvm::isa(r.front())) return true; - if (l.front().isa() && r.front().isa()) + if (llvm::isa(l.front()) && llvm::isa(r.front())) return true; return false; } @@ -1647,9 +1652,9 @@ bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { if (l.size() != 1 || r.size() != 1) return false; - if (l.front().isa() && r.front().isa()) + if (llvm::isa(l.front()) && llvm::isa(r.front())) return true; - if (l.front().isa() && r.front().isa()) + if (llvm::isa(l.front()) && llvm::isa(r.front())) return true; return false; } @@ -1674,8 +1679,8 @@ MLIRContext *context, std::optional location, ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { - if (operands[0].getType().isa() || - operands[1].getType().isa()) + if (llvm::isa(operands[0].getType()) || + llvm::isa(operands[1].getType())) inferredReturnTypes.assign({SizeType::get(context)}); else inferredReturnTypes.assign({IndexType::get(context)}); @@ -1694,7 +1699,7 @@ //===----------------------------------------------------------------------===// OpFoldResult ShapeOfOp::fold(FoldAdaptor) { - auto type = getOperand().getType().dyn_cast(); + auto type = llvm::dyn_cast(getOperand().getType()); if (!type || !type.hasStaticShape()) return nullptr; Builder builder(getContext()); @@ -1707,9 +1712,9 @@ LogicalResult matchAndRewrite(shape::ShapeOfOp op, PatternRewriter &rewriter) const override { - if (!op.getArg().getType().isa()) + if (!llvm::isa(op.getArg().getType())) return failure(); - if (op.getType().isa()) + if (llvm::isa(op.getType())) return failure(); rewriter.replaceOpWithNewOp(op.getOperation(), @@ -1732,7 +1737,7 @@ LogicalResult matchAndRewrite(tensor::CastOp op, PatternRewriter &rewriter) const override { - auto ty = op.getType().dyn_cast(); + auto ty = llvm::dyn_cast(op.getType()); if (!ty || ty.getRank() != 1) return failure(); @@ -1741,7 +1746,7 @@ return failure(); // Argument type must be ranked and must not conflict. - auto argTy = shapeOfOp.getArg().getType().dyn_cast(); + auto argTy = llvm::dyn_cast(shapeOfOp.getArg().getType()); if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank())) return failure(); @@ -1761,10 +1766,10 @@ MLIRContext *context, std::optional location, ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { - if (operands[0].getType().isa()) + if (llvm::isa(operands[0].getType())) inferredReturnTypes.assign({ShapeType::get(context)}); else { - auto shapedTy = operands[0].getType().cast(); + auto shapedTy = llvm::cast(operands[0].getType()); int64_t rank = shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamic; Type indexTy = IndexType::get(context); @@ -1783,10 +1788,11 @@ Type lhs = l.front(); Type rhs = r.front(); - if (!lhs.isa() || !rhs.isa()) + if (!llvm::isa(lhs) || + !llvm::isa(rhs)) return false; - if (lhs.isa() || rhs.isa()) + if (llvm::isa(lhs) || llvm::isa(rhs)) // Shape type is compatible with all other valid return types. return true; @@ -1819,7 +1825,8 @@ bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { if (inputs.size() != 1 || outputs.size() != 1) return false; - return inputs[0].isa() && outputs[0].isa(); + return llvm::isa(inputs[0]) && + llvm::isa(outputs[0]); } //===----------------------------------------------------------------------===// @@ -1884,16 +1891,16 @@ bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { if (inputs.size() != 1 || outputs.size() != 1) return false; - if (auto inputTensor = inputs[0].dyn_cast()) { - if (!inputTensor.getElementType().isa() || + if (auto inputTensor = llvm::dyn_cast(inputs[0])) { + if (!llvm::isa(inputTensor.getElementType()) || inputTensor.getRank() != 1) return false; - } else if (!inputs[0].isa()) { + } else if (!llvm::isa(inputs[0])) { return false; } - TensorType outputTensor = outputs[0].dyn_cast(); - return outputTensor && outputTensor.getElementType().isa(); + TensorType outputTensor = llvm::dyn_cast(outputs[0]); + return outputTensor && llvm::isa(outputTensor.getElementType()); } //===----------------------------------------------------------------------===// @@ -1911,7 +1918,7 @@ bodyBlock.addArgument(builder.getIndexType(), result.location); Type elementType; - if (auto tensorType = shape.getType().dyn_cast()) + if (auto tensorType = llvm::dyn_cast(shape.getType())) elementType = tensorType.getElementType(); else elementType = SizeType::get(builder.getContext()); @@ -1934,7 +1941,7 @@ << blockArgsCount << " arguments"; // The first block argument is the index and must always be of type `index`. - if (!block.getArgument(0).getType().isa()) + if (!llvm::isa(block.getArgument(0).getType())) return emitOpError( "argument 0 of ReduceOp body is expected to be of IndexType"); @@ -1942,12 +1949,12 @@ // `index`, depending on whether the reduce operation is applied to a shape or // to an extent tensor. Type extentTy = block.getArgument(1).getType(); - if (getShape().getType().isa()) { - if (!extentTy.isa()) + if (llvm::isa(getShape().getType())) { + if (!llvm::isa(extentTy)) return emitOpError("argument 1 of ReduceOp body is expected to be of " "SizeType if the ReduceOp operates on a ShapeType"); } else { - if (!extentTy.isa()) + if (!llvm::isa(extentTy)) return emitOpError( "argument 1 of ReduceOp body is expected to be of IndexType if the " "ReduceOp operates on an extent tensor"); diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -261,10 +261,10 @@ if (attrName == "dimLevelType") { Attribute attr; RETURN_ON_FAIL(parser.parseAttribute(attr)); - auto arrayAttr = attr.dyn_cast(); + auto arrayAttr = llvm::dyn_cast(attr); ERROR_IF(!arrayAttr, "expected an array for dimension level types") for (auto i : arrayAttr) { - auto strAttr = i.dyn_cast(); + auto strAttr = llvm::dyn_cast(i); ERROR_IF(!strAttr, "expected a string value in dimension level types") auto strVal = strAttr.getValue(); if (auto optDLT = parseDLT(strVal)) { @@ -279,25 +279,25 @@ } else if (attrName == "dimOrdering") { Attribute attr; RETURN_ON_FAIL(parser.parseAttribute(attr)) - auto affineAttr = attr.dyn_cast(); + auto affineAttr = llvm::dyn_cast(attr); ERROR_IF(!affineAttr, "expected an affine map for dimension ordering") dimOrd = affineAttr.getValue(); } else if (attrName == "higherOrdering") { Attribute attr; RETURN_ON_FAIL(parser.parseAttribute(attr)) - auto affineAttr = attr.dyn_cast(); + auto affineAttr = llvm::dyn_cast(attr); ERROR_IF(!affineAttr, "expected an affine map for higher ordering") higherOrd = affineAttr.getValue(); } else if (attrName == "posWidth") { Attribute attr; RETURN_ON_FAIL(parser.parseAttribute(attr)) - auto intAttr = attr.dyn_cast(); + auto intAttr = llvm::dyn_cast(attr); ERROR_IF(!intAttr, "expected an integral position bitwidth") posWidth = intAttr.getInt(); } else if (attrName == "crdWidth") { Attribute attr; RETURN_ON_FAIL(parser.parseAttribute(attr)) - auto intAttr = attr.dyn_cast(); + auto intAttr = llvm::dyn_cast(attr); ERROR_IF(!intAttr, "expected an integral index bitwidth") crdWidth = intAttr.getInt(); } else if (attrName == "slice") { @@ -305,7 +305,7 @@ // Dispatches to DimSliceAttr to skip mnemonic bool finished = false; while (auto attr = SparseTensorDimSliceAttr::parse(parser, nullptr)) { - auto sliceAttr = attr.cast(); + auto sliceAttr = llvm::cast(attr); slices.push_back(sliceAttr); if (parser.parseOptionalComma().failed()) { finished = true; @@ -442,9 +442,9 @@ SparseTensorEncodingAttr mlir::sparse_tensor::getSparseTensorEncoding(Type type) { - if (auto ttp = type.dyn_cast()) - return ttp.getEncoding().dyn_cast_or_null(); - if (auto mdtp = type.dyn_cast()) + if (auto ttp = llvm::dyn_cast(type)) + return llvm::dyn_cast_or_null(ttp.getEncoding()); + if (auto mdtp = llvm::dyn_cast(type)) return mdtp.getEncoding(); return nullptr; } @@ -725,12 +725,12 @@ } LogicalResult ConvertOp::verify() { - if (auto tp1 = getSource().getType().dyn_cast()) { - if (auto tp2 = getDest().getType().dyn_cast()) { + if (auto tp1 = llvm::dyn_cast(getSource().getType())) { + if (auto tp2 = llvm::dyn_cast(getDest().getType())) { if (tp1.getRank() != tp2.getRank()) return emitError("unexpected conversion mismatch in rank"); auto dstEnc = - tp2.getEncoding().dyn_cast_or_null(); + llvm::dyn_cast_or_null(tp2.getEncoding()); if (dstEnc && dstEnc.isSlice()) return emitError("cannot convert to a sparse tensor slice"); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp @@ -112,7 +112,7 @@ Value mem, ArrayRef idxs, Value vmask) { VectorType vtp = vectorType(vl, mem); Value pass = constantZero(rewriter, loc, vtp); - if (idxs.back().getType().isa()) { + if (llvm::isa(idxs.back().getType())) { SmallVector scalarArgs(idxs.begin(), idxs.end()); Value indexVec = idxs.back(); scalarArgs.back() = constantIndex(rewriter, loc, 0); @@ -129,7 +129,7 @@ /// the last index, i.e. back(). static void genVectorStore(PatternRewriter &rewriter, Location loc, Value mem, ArrayRef idxs, Value vmask, Value rhs) { - if (idxs.back().getType().isa()) { + if (llvm::isa(idxs.back().getType())) { SmallVector scalarArgs(idxs.begin(), idxs.end()); Value indexVec = idxs.back(); scalarArgs.back() = constantIndex(rewriter, loc, 0); @@ -260,7 +260,7 @@ // innermost loop simply pass through as well. // Example: // a[i][j] for both i and j - if (auto arg = sub.dyn_cast()) { + if (auto arg = llvm::dyn_cast(sub)) { if (isInvariantArg(arg, block) == innermost) return false; if (codegen) @@ -298,8 +298,8 @@ Location loc = forOp.getLoc(); Value vload = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask); - Type etp = vload.getType().cast().getElementType(); - if (!etp.isa()) { + Type etp = llvm::cast(vload.getType()).getElementType(); + if (!llvm::isa(etp)) { if (etp.getIntOrFloatBitWidth() < 32) vload = rewriter.create( loc, vectorType(vl, rewriter.getI32Type()), vload); @@ -318,7 +318,7 @@ Value inv = load.getOperand(0); Value idx = load.getOperand(1); if (isInvariantValue(inv, block)) { - if (auto arg = idx.dyn_cast()) { + if (auto arg = llvm::dyn_cast(idx)) { if (isInvariantArg(arg, block) || !innermost) return false; if (codegen) @@ -369,7 +369,7 @@ if (!VectorType::isValidElementType(exp.getType())) return false; // A block argument is invariant/reduction/index. - if (auto arg = exp.dyn_cast()) { + if (auto arg = llvm::dyn_cast(exp)) { if (arg == forOp.getInductionVar()) { // We encountered a single, innermost index inside the computation, // such as a[i] = i, which must convert to [i, i+1, ...]. diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp @@ -130,7 +130,8 @@ ArrayRef dstStaticShape, ArrayRef reassocation) { return dstStaticShape.size() > - static_cast(src.getType().cast().getRank()) + static_cast( + llvm::cast(src.getType()).getRank()) ? getExpandedOutputShapeFromInputShape( builder, loc, src, dstStaticShape, reassocation) : getCollapsedOutputShapeFromInputShape( @@ -185,7 +186,7 @@ return; } int64_t staticValue = - valueOrAttr.get().cast().getInt(); + llvm::cast(valueOrAttr.get()).getInt(); expr = expr + staticValue; }; addOpFoldResult(lowPad[dim]); diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -42,13 +42,13 @@ return op; if (complex::ConstantOp::isBuildableWith(value, type)) return builder.create(loc, type, - value.cast()); + llvm::cast(value)); return nullptr; } SmallVector tensor::getMixedSizes(OpBuilder &builder, Location loc, Value value) { - auto tensorType = value.getType().cast(); + auto tensorType = llvm::cast(value.getType()); SmallVector result; for (int64_t i = 0; i < tensorType.getRank(); ++i) { if (tensorType.isDynamicDim(i)) { @@ -63,7 +63,7 @@ FailureOr tensor::getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult) { - auto tensorType = opResult.getType().dyn_cast(); + auto tensorType = llvm::dyn_cast(opResult.getType()); assert(tensorType && "expected tensor type"); // If the op has a destination, it implements DestinationStyleOpInterface and @@ -100,7 +100,7 @@ Operation *op, SmallVector &result) { for (OpResult opResult : op->getResults()) { - if (opResult.getType().isa()) { + if (llvm::isa(opResult.getType())) { FailureOr destination = getOrCreateDestination(b, loc, opResult); if (failed(destination)) return failure(); @@ -111,8 +111,8 @@ } bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) { - if (auto rtp1 = tp1.dyn_cast()) { - if (auto rtp2 = tp2.dyn_cast()) + if (auto rtp1 = llvm::dyn_cast(tp1)) { + if (auto rtp2 = llvm::dyn_cast(tp2)) return rtp1.getShape() == rtp2.getShape() && rtp1.getElementType() == rtp2.getElementType(); return false; @@ -131,7 +131,7 @@ // Rank-reduced dims must have a static unit dimension. bool isStaticUnitSize = size.value().is() && - size.value().get().cast().getInt() == 1; + llvm::cast(size.value().get()).getInt() == 1; if (shapePos == static_cast(reducedShape.size())) { // There are no more dims in the reduced shape. All remaining sizes must @@ -220,8 +220,8 @@ /// Returns true if `target` is a ranked tensor type that preserves static /// information available in the `source` ranked tensor type. bool mlir::tensor::preservesStaticInformation(Type source, Type target) { - auto sourceType = source.dyn_cast(); - auto targetType = target.dyn_cast(); + auto sourceType = llvm::dyn_cast(source); + auto targetType = llvm::dyn_cast(target); // Requires RankedTensorType. if (!sourceType || !targetType) @@ -322,8 +322,8 @@ if (inputs.size() != 1 || outputs.size() != 1) return false; Type a = inputs.front(), b = outputs.front(); - auto aT = a.dyn_cast(); - auto bT = b.dyn_cast(); + auto aT = llvm::dyn_cast(a); + auto bT = llvm::dyn_cast(b); if (!aT || !bT) return false; @@ -380,9 +380,9 @@ return failure(); auto sourceType = - tensorCastOperand.getOperand().getType().cast(); - auto intermediateType = tensorCastOperand.getType().cast(); - auto resultType = tensorCast.getType().cast(); + llvm::cast(tensorCastOperand.getOperand().getType()); + auto intermediateType = llvm::cast(tensorCastOperand.getType()); + auto resultType = llvm::cast(tensorCast.getType()); // We can remove the intermediate cast if joining all three produces the // same result as just joining the source and result shapes. @@ -427,15 +427,15 @@ tensorCast.getOperand().getDefiningOp(); // Cannot fold cast to unranked tensor. - auto rankedResultType = tensorCast.getType().dyn_cast(); + auto rankedResultType = + llvm::dyn_cast(tensorCast.getType()); if (!rankedResultType) return failure(); if (!extractOperand || !canFoldIntoProducerOp(tensorCast) || - rankedResultType.getShape() == tensorCast.getSource() - .getType() - .cast() - .getShape()) + rankedResultType.getShape() == + llvm::cast(tensorCast.getSource().getType()) + .getShape()) return failure(); SmallVector sizes = extractOperand.getMixedSizes(); @@ -506,7 +506,7 @@ return {}; // Folding for unranked types (UnrankedTensorType) is not supported. - auto tensorType = getSource().getType().dyn_cast(); + auto tensorType = llvm::dyn_cast(getSource().getType()); if (!tensorType) return {}; @@ -527,7 +527,7 @@ // Fold dim to the operand of tensor.generate. if (auto fromElements = dyn_cast_or_null(definingOp)) { auto resultType = - fromElements.getResult().getType().cast(); + llvm::cast(fromElements.getResult().getType()); // The case where the type encodes the size of the dimension is handled // above. assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()])); @@ -751,7 +751,8 @@ if (!producer) return failure(); - auto resultType = castOp->getResult(0).getType().cast(); + auto resultType = + llvm::cast(castOp->getResult(0).getType()); ArrayRef resultShape = resultType.getShape(); SmallVector currMixedSizes = producer.getMixedSizes(); SmallVector newMixedSizes; @@ -765,7 +766,7 @@ // result dim matches. if (auto attr = currDim.dyn_cast()) { if (ShapedType::isDynamic(newDim) || - newDim != attr.cast().getInt()) { + newDim != llvm::cast(attr).getInt()) { // Something is off, the cast result shape cannot be more dynamic // than the empty tensor result shape (enforced by // `canFoldIntoProducer`). Abort for now. @@ -826,7 +827,7 @@ auto tensorCast = extract.getTensor().getDefiningOp(); if (!tensorCast) return failure(); - if (!tensorCast.getSource().getType().isa()) + if (!llvm::isa(tensorCast.getSource().getType())) return failure(); rewriter.replaceOpWithNewOp( extract, tensorCast.getSource(), extract.getIndices()); @@ -843,7 +844,7 @@ LogicalResult ExtractOp::verify() { // Verify the # indices match if we have a ranked type. - auto tensorType = getTensor().getType().cast(); + auto tensorType = llvm::cast(getTensor().getType()); if (tensorType.getRank() != static_cast(getIndices().size())) return emitOpError("incorrect number of indices for extract_element"); return success(); @@ -853,20 +854,20 @@ // If this is a splat elements attribute, simply return the value. All of // the elements of a splat attribute are the same. if (Attribute tensor = adaptor.getTensor()) - if (auto splatTensor = tensor.dyn_cast()) + if (auto splatTensor = llvm::dyn_cast(tensor)) return splatTensor.getSplatValue(); // Collect the constant indices into the tensor. SmallVector indices; for (Attribute indice : adaptor.getIndices()) { - if (!indice || !indice.isa()) + if (!indice || !llvm::isa(indice)) return {}; - indices.push_back(indice.cast().getInt()); + indices.push_back(llvm::cast(indice).getInt()); } // Fold extract(from_elements(...)). if (auto fromElementsOp = getTensor().getDefiningOp()) { - auto tensorType = fromElementsOp.getType().cast(); + auto tensorType = llvm::cast(fromElementsOp.getType()); auto rank = tensorType.getRank(); assert(static_cast(indices.size()) == tensorType.getRank() && "rank mismatch"); @@ -887,7 +888,7 @@ // If this is an elements attribute, query the value at the given indices. if (Attribute tensor = adaptor.getTensor()) { - auto elementsAttr = tensor.dyn_cast(); + auto elementsAttr = llvm::dyn_cast(tensor); if (elementsAttr && elementsAttr.isValidIndex(indices)) return elementsAttr.getValues()[indices]; } @@ -1070,7 +1071,7 @@ LogicalResult InsertOp::verify() { // Verify the # indices match if we have a ranked type. - auto destType = getDest().getType().cast(); + auto destType = llvm::cast(getDest().getType()); if (destType.getRank() != static_cast(getIndices().size())) return emitOpError("incorrect number of indices"); return success(); @@ -1080,7 +1081,7 @@ Attribute scalar = adaptor.getScalar(); Attribute dest = adaptor.getDest(); if (scalar && dest) - if (auto splatDest = dest.dyn_cast()) + if (auto splatDest = llvm::dyn_cast(dest)) if (scalar == splatDest.getSplatValue()) return dest; return {}; @@ -1113,7 +1114,7 @@ LogicalResult GenerateOp::verify() { // Ensure that the tensor type has as many dynamic dimensions as are // specified by the operands. - RankedTensorType resultTy = getType().cast(); + RankedTensorType resultTy = llvm::cast(getType()); if (getNumOperands() != resultTy.getNumDynamicDims()) return emitError("must have as many index operands as dynamic extents " "in the result type"); @@ -1122,7 +1123,7 @@ } LogicalResult GenerateOp::verifyRegions() { - RankedTensorType resultTy = getType().cast(); + RankedTensorType resultTy = llvm::cast(getType()); // Ensure that region arguments span the index space. if (!llvm::all_of(getBody().getArgumentTypes(), [](Type ty) { return ty.isIndex(); })) @@ -1150,7 +1151,7 @@ // Build and populate body. OpBuilder::InsertionGuard guard(b); Region *bodyRegion = result.regions.front().get(); - auto rank = resultTy.cast().getRank(); + auto rank = llvm::cast(resultTy).getRank(); SmallVector argumentTypes(rank, b.getIndexType()); SmallVector argumentLocs(rank, result.location); Block *bodyBlock = @@ -1170,7 +1171,7 @@ LogicalResult matchAndRewrite(GenerateOp tensorFromElements, PatternRewriter &rewriter) const final { auto resultType = - tensorFromElements.getResult().getType().cast(); + llvm::cast(tensorFromElements.getResult().getType()); if (resultType.hasStaticShape()) return failure(); @@ -1261,7 +1262,7 @@ OpFoldResult RankOp::fold(FoldAdaptor adaptor) { // Constant fold rank when the rank of the operand is known. auto type = getOperand().getType(); - auto shapedType = type.dyn_cast(); + auto shapedType = llvm::dyn_cast(type); if (shapedType && shapedType.hasRank()) return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank()); return IntegerAttr(); @@ -1284,17 +1285,17 @@ } LogicalResult ReshapeOp::verify() { - TensorType operandType = getSource().getType().cast(); - TensorType resultType = getResult().getType().cast(); + TensorType operandType = llvm::cast(getSource().getType()); + TensorType resultType = llvm::cast(getResult().getType()); if (operandType.getElementType() != resultType.getElementType()) return emitOpError("element types of source and destination tensor " "types should be the same"); int64_t shapeSize = - getShape().getType().cast().getDimSize(0); - auto resultRankedType = resultType.dyn_cast(); - auto operandRankedType = operandType.dyn_cast(); + llvm::cast(getShape().getType()).getDimSize(0); + auto resultRankedType = llvm::dyn_cast(resultType); + auto operandRankedType = llvm::dyn_cast(operandType); if (resultRankedType) { if (operandRankedType && resultRankedType.hasStaticShape() && @@ -1392,7 +1393,7 @@ ArrayRef reassociation, ArrayRef attrs) { auto resultType = inferCollapsedType( - src.getType().cast(), + llvm::cast(src.getType()), getSymbolLessAffineMaps( convertReassociationIndicesToExprs(b.getContext(), reassociation))); build(b, result, resultType, src, attrs); @@ -1488,7 +1489,7 @@ if (!fromElements) return failure(); - auto shapedTy = reshapeOp.getType().template cast(); + auto shapedTy = llvm::cast(reshapeOp.getType()); if (!shapedTy.hasStaticShape()) return failure(); @@ -1510,7 +1511,7 @@ return failure(); RankedTensorType srcType = - castOp.getSource().getType().cast(); + llvm::cast(castOp.getSource().getType()); RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType( srcType, collapseShapeOp.getReassociationMaps()); @@ -1693,9 +1694,8 @@ ArrayRef offsets, ArrayRef sizes, ArrayRef strides) { // Type inferred in the absence of rank-reducing behavior. - auto inferredType = - inferResultType(sourceRankedTensorType, offsets, sizes, strides) - .cast(); + auto inferredType = llvm::cast( + inferResultType(sourceRankedTensorType, offsets, sizes, strides)); int rankDiff = inferredType.getRank() - desiredResultRank; if (rankDiff > 0) { auto shape = inferredType.getShape(); @@ -1739,13 +1739,11 @@ dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes); dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); - auto sourceRankedTensorType = source.getType().cast(); + auto sourceRankedTensorType = llvm::cast(source.getType()); // Structuring implementation this way avoids duplication between builders. if (!resultType) { - resultType = - ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets, - staticSizes, staticStrides) - .cast(); + resultType = llvm::cast(ExtractSliceOp::inferResultType( + sourceRankedTensorType, staticOffsets, staticSizes, staticStrides)); } build(b, result, resultType, source, dynamicOffsets, dynamicSizes, dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets), @@ -1831,7 +1829,7 @@ FailureOr ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value, ArrayRef desiredShape) { - auto sourceTensorType = value.getType().dyn_cast(); + auto sourceTensorType = llvm::dyn_cast(value.getType()); assert(sourceTensorType && "not a ranked tensor type"); auto sourceShape = sourceTensorType.getShape(); if (sourceShape.equals(desiredShape)) @@ -1968,8 +1966,8 @@ return failure(); // Dynamic result shape is not supported. - auto sourceType = op.getSource().getType().cast(); - auto resultType = op.getResult().getType().cast(); + auto sourceType = llvm::cast(op.getSource().getType()); + auto resultType = llvm::cast(op.getResult().getType()); if (!sourceType.hasStaticShape() || !resultType.hasStaticShape()) return failure(); @@ -2004,13 +2002,13 @@ // New attribute constructed by the sliced values. DenseElementsAttr newAttr; - if (auto elems = attr.dyn_cast()) { + if (auto elems = llvm::dyn_cast(attr)) { SmallVector outValues; outValues.reserve(sourceType.getNumElements()); sliceElements( elems.begin(), counts, offsets, sizes, strides, &outValues); newAttr = DenseElementsAttr::get(resultType, outValues); - } else if (auto elems = attr.dyn_cast()) { + } else if (auto elems = llvm::dyn_cast(attr)) { SmallVector outValues; outValues.reserve(sourceType.getNumElements()); sliceElements( @@ -2109,7 +2107,7 @@ OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) { if (auto splat = adaptor.getSource().dyn_cast_or_null()) { - auto resultType = getResult().getType().cast(); + auto resultType = llvm::cast(getResult().getType()); if (resultType.hasStaticShape()) return splat.resizeSplat(resultType); } @@ -2124,7 +2122,7 @@ Value mlir::tensor::createCanonicalRankReducingExtractSliceOp( OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) { - auto rankedTensorType = tensor.getType().cast(); + auto rankedTensorType = llvm::cast(tensor.getType()); unsigned rank = rankedTensorType.getRank(); SmallVector offsets(rank, b.getIndexAttr(0)); SmallVector sizes = getMixedSizes(b, loc, tensor); @@ -2372,8 +2370,8 @@ auto src = (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource()); auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest()); - auto srcType = src.getType().template dyn_cast(); - auto dstType = dst.getType().template dyn_cast(); + auto srcType = llvm::dyn_cast(src.getType()); + auto dstType = llvm::dyn_cast(dst.getType()); if (!srcType || !dstType) return failure(); if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(), @@ -2482,7 +2480,7 @@ Location loc, Value tensor, Value dest) { - auto rankedTensorType = dest.getType().cast(); + auto rankedTensorType = llvm::cast(dest.getType()); unsigned rank = rankedTensorType.getRank(); SmallVector offsets(rank, b.getIndexAttr(0)); SmallVector sizes = getMixedSizes(b, loc, dest); @@ -2514,8 +2512,8 @@ } LogicalResult PadOp::verify() { - auto sourceType = getSource().getType().cast(); - auto resultType = getResult().getType().cast(); + auto sourceType = llvm::cast(getSource().getType()); + auto resultType = llvm::cast(getResult().getType()); auto expectedType = PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh()); if (!expectedType) { @@ -2542,7 +2540,7 @@ LogicalResult PadOp::verifyRegions() { auto ®ion = getRegion(); - unsigned rank = getResult().getType().cast().getRank(); + unsigned rank = llvm::cast(getResult().getType()).getRank(); Block &block = region.front(); if (block.getNumArguments() != rank) return emitError("expected the block to have ") << rank << " arguments"; @@ -2557,7 +2555,7 @@ // Ensure that the region yields an element of the right type. auto yieldOp = llvm::cast(block.getTerminator()); if (yieldOp.getValue().getType() != - getType().cast().getElementType()) + llvm::cast(getType()).getElementType()) return emitOpError("expected yield type to match shape element type"); return success(); @@ -2597,7 +2595,7 @@ Value source, ArrayRef staticLow, ArrayRef staticHigh, ValueRange low, ValueRange high, bool nofold, ArrayRef attrs) { - auto sourceType = source.getType().cast(); + auto sourceType = llvm::cast(source.getType()); if (!resultType) resultType = inferResultType(sourceType, staticLow, staticHigh); build(b, result, resultType, source, low, high, @@ -2609,7 +2607,7 @@ void PadOp::build(OpBuilder &b, OperationState &result, Type resultType, Value source, ValueRange low, ValueRange high, bool nofold, ArrayRef attrs) { - auto sourceType = source.getType().cast(); + auto sourceType = llvm::cast(source.getType()); unsigned rank = sourceType.getRank(); SmallVector staticVector(rank, ShapedType::kDynamic); build(b, result, resultType, source, staticVector, staticVector, low, high, @@ -2620,7 +2618,7 @@ Value source, ArrayRef low, ArrayRef high, bool nofold, ArrayRef attrs) { - auto sourceType = source.getType().cast(); + auto sourceType = llvm::cast(source.getType()); SmallVector dynamicLow, dynamicHigh; SmallVector staticLow, staticHigh; // staticLow and staticHigh have full information of the padding config. @@ -2632,7 +2630,7 @@ if (!resultType) { resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh); } - assert(resultType.isa()); + assert(llvm::isa(resultType)); build(b, result, resultType, source, dynamicLow, dynamicHigh, b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh), nofold ? b.getUnitAttr() : UnitAttr()); @@ -2647,7 +2645,7 @@ // Add a region and a block to yield the pad value. Region *region = result.regions[0].get(); - int sourceRank = source.getType().cast().getRank(); + int sourceRank = llvm::cast(source.getType()).getRank(); SmallVector blockArgTypes(sourceRank, b.getIndexType()); SmallVector blockArgLocs(sourceRank, result.location); @@ -2700,7 +2698,7 @@ return failure(); auto newResultType = PadOp::inferResultType( - castOp.getSource().getType().cast(), + llvm::cast(castOp.getSource().getType()), padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(), padTensorOp.getResultType().getShape()); @@ -2919,9 +2917,9 @@ LogicalResult matchAndRewrite(PadOp padTensorOp, PatternRewriter &rewriter) const override { Value input = padTensorOp.getSource(); - if (!input.getType().isa()) + if (!llvm::isa(input.getType())) return failure(); - auto inputDims = input.getType().cast().getShape(); + auto inputDims = llvm::cast(input.getType()).getShape(); auto inputRank = inputDims.size(); auto oldResultType = @@ -3240,7 +3238,7 @@ "applies to only pack or unpack operations"); int64_t destRank = op.getDestRank(); reifiedReturnShapes.resize(1, SmallVector(destRank)); - ShapedType resultType = op.getResult().getType().template cast(); + ShapedType resultType = llvm::cast(op.getResult().getType()); for (auto dim : llvm::seq(0, destRank)) { if (resultType.isDynamicDim(dim)) { reifiedReturnShapes[0][dim] = @@ -3655,8 +3653,8 @@ }; SmallVector mixedSizes; - for (auto [index, value] : - llvm::enumerate(source.getType().cast().getShape())) { + for (auto [index, value] : llvm::enumerate( + llvm::cast(source.getType()).getShape())) { if (ShapedType::isDynamic(value)) mixedSizes.push_back(b.create(loc, source, index).getResult()); else @@ -3671,7 +3669,7 @@ applyPermutationToVector(mixedSizes, outerDimsPerm); mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end()); - auto elemType = source.getType().cast().getElementType(); + auto elemType = llvm::cast(source.getType()).getElementType(); return b.create(loc, mixedSizes, elemType); } @@ -3789,7 +3787,7 @@ bool PackOp::isLikePad() { auto packedTensorType = - (*this)->getResultTypes().front().cast(); + llvm::cast((*this)->getResultTypes().front()); return isLikePadUnPad(*this, packedTensorType); } @@ -3861,7 +3859,7 @@ }; SmallVector mixedSizes; - auto srcType = source.getType().cast(); + auto srcType = llvm::cast(source.getType()); for (auto i : llvm::seq(0, srcType.getRank() - innerTileSizes.size())) { if (srcType.isDynamicDim(i)) @@ -3944,7 +3942,7 @@ // If no operand comes from a tensor::CastOp and can be folded then fail. bool hasTensorCastOperand = llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) { - if (opOperand.get().isa()) + if (llvm::isa(opOperand.get())) return false; auto castOp = opOperand.get().getDefiningOp(); return castOp && canFoldIntoConsumerOp(castOp); @@ -3961,7 +3959,7 @@ bool fold = canFoldIntoConsumerOp(tensorCastOp); newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get()); if (op.isDpsInit(&opOperand) && - !newOperands.back().getType().isa()) + !llvm::isa(newOperands.back().getType())) newResultTypes.push_back(newOperands.back().getType()); } diff --git a/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp @@ -24,8 +24,8 @@ auto castOp = cast(op); assert(value == castOp.getResult() && "invalid value"); - if (castOp.getResult().getType().isa() && - castOp.getSource().getType().isa()) { + if (llvm::isa(castOp.getResult().getType()) && + llvm::isa(castOp.getSource().getType())) { cstr.bound(value)[dim] == cstr.getExpr(castOp.getSource(), dim); } } @@ -100,7 +100,8 @@ auto rankOp = cast(op); assert(value == rankOp.getResult() && "invalid value"); - auto tensorType = rankOp.getTensor().getType().dyn_cast(); + auto tensorType = + llvm::dyn_cast(rankOp.getTensor().getType()); if (!tensorType) return; cstr.bound(value) == tensorType.getRank(); diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -88,8 +88,8 @@ LogicalResult matchAndRewrite(tosa::ReshapeOp op, PatternRewriter &rewriter) const override { Value input = op.getInput1(); - ShapedType inputTy = input.getType().cast(); - ShapedType resultTy = op.getType().cast(); + ShapedType inputTy = llvm::cast(input.getType()); + ShapedType resultTy = llvm::cast(op.getType()); if (inputTy.getElementType() != resultTy.getElementType()) return rewriter.notifyMatchFailure(op, "element type does not match."); @@ -106,7 +106,7 @@ // Build new const op with correct output shape DenseElementsAttr outputAttr = inputAttr.reshape( - inputAttr.getType().cast().clone(op.getNewShape())); + llvm::cast(inputAttr.getType()).clone(op.getNewShape())); rewriter.replaceOpWithNewOp(op, resultTy, outputAttr); return success(); } @@ -198,7 +198,7 @@ } auto input = op.getInput1(); - auto inputTy = input.getType().cast(); + auto inputTy = llvm::cast(input.getType()); if (!inputTy.hasRank()) return rewriter.notifyMatchFailure(op, "Unranked input."); @@ -255,15 +255,15 @@ auto input = op.getInput1(); auto padding = op.getPadding(); - ShapedType inputTy = input.getType().cast(); + ShapedType inputTy = llvm::cast(input.getType()); Type elementTy = inputTy.getElementType(); Attribute constantAttr; - if (elementTy.isa()) { + if (llvm::isa(elementTy)) { constantAttr = rewriter.getFloatAttr(elementTy, 0.0); - } else if (elementTy.isa() && !op.getQuantizationInfo()) { + } else if (llvm::isa(elementTy) && !op.getQuantizationInfo()) { constantAttr = rewriter.getIntegerAttr(elementTy, 0); - } else if (elementTy.isa() && op.getQuantizationInfo()) { + } else if (llvm::isa(elementTy) && op.getQuantizationInfo()) { auto value = op.getQuantizationInfo()->getInputZp(); constantAttr = rewriter.getIntegerAttr(elementTy, value); } @@ -298,8 +298,8 @@ PatternRewriter &rewriter) const override { Value input = op.getInput(); Value output = op.getOutput(); - ShapedType inputType = input.getType().cast(); - ShapedType outputType = output.getType().cast(); + ShapedType inputType = llvm::cast(input.getType()); + ShapedType outputType = llvm::cast(output.getType()); if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) { return failure(); @@ -332,8 +332,7 @@ LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override { Value input = op.getInput(); - auto inputType = - op.getInput().getType().template dyn_cast(); + auto inputType = llvm::dyn_cast(op.getInput().getType()); auto inputElementType = inputType.getElementType(); if (!inputType.hasStaticShape()) { @@ -373,7 +372,7 @@ return failure(); } - if (inputElementType.isa()) { + if (llvm::isa(inputElementType)) { int64_t minClamp = op.getMinInt(); int64_t maxClamp = op.getMaxInt(); @@ -498,19 +497,19 @@ DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, RankedTensorType returnTy) { if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) { - auto lETy = lhs.getType().cast().getElementType(); - auto rETy = rhs.getType().cast().getElementType(); + auto lETy = llvm::cast(lhs.getType()).getElementType(); + auto rETy = llvm::cast(rhs.getType()).getElementType(); if (lETy != rETy) return {}; - if (lETy.isa()) { + if (llvm::isa(lETy)) { APInt l = lhs.getSplatValue(); APInt r = rhs.getSplatValue(); auto result = IntFolder()(l, r); return DenseElementsAttr::get(returnTy, result); } - if (lETy.isa()) { + if (llvm::isa(lETy)) { APFloat l = lhs.getSplatValue(); APFloat r = rhs.getSplatValue(); auto result = FloatFolder()(l, r); @@ -522,18 +521,18 @@ } static bool isSplatZero(Type elemType, DenseElementsAttr val) { - if (elemType.isa()) + if (llvm::isa(elemType)) return val && val.isSplat() && val.getSplatValue().isZero(); - if (elemType.isa()) + if (llvm::isa(elemType)) return val && val.isSplat() && val.getSplatValue().isZero(); return false; } static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) { - if (elemType.isa()) + if (llvm::isa(elemType)) return val && val.isSplat() && val.getSplatValue().isExactlyValue(1.0); - if (elemType.isa()) { + if (llvm::isa(elemType)) { const int64_t shifted = 1LL << shift; return val && val.isSplat() && val.getSplatValue().getSExtValue() == shifted; @@ -542,9 +541,9 @@ } OpFoldResult AddOp::fold(FoldAdaptor adaptor) { - auto lhsTy = getInput1().getType().dyn_cast(); - auto rhsTy = getInput2().getType().dyn_cast(); - auto resultTy = getType().dyn_cast(); + auto lhsTy = llvm::dyn_cast(getInput1().getType()); + auto rhsTy = llvm::dyn_cast(getInput2().getType()); + auto resultTy = llvm::dyn_cast(getType()); if (!lhsTy || !rhsTy || !resultTy) return {}; @@ -565,9 +564,9 @@ } OpFoldResult DivOp::fold(FoldAdaptor adaptor) { - auto lhsTy = getInput1().getType().dyn_cast(); - auto rhsTy = getInput2().getType().dyn_cast(); - auto resultTy = getType().dyn_cast(); + auto lhsTy = llvm::dyn_cast(getInput1().getType()); + auto rhsTy = llvm::dyn_cast(getInput2().getType()); + auto resultTy = llvm::dyn_cast(getType()); if (!lhsTy || !rhsTy || !resultTy) return {}; if (lhsTy != rhsTy) @@ -577,17 +576,19 @@ auto lhsAttr = adaptor.getInput1().dyn_cast_or_null(); auto rhsAttr = adaptor.getInput2().dyn_cast_or_null(); if (lhsAttr && lhsAttr.isSplat()) { - if (resultETy.isa() && lhsAttr.getSplatValue().isZero()) + if (llvm::isa(resultETy) && + lhsAttr.getSplatValue().isZero()) return lhsAttr; } if (rhsAttr && rhsAttr.isSplat()) { - if (resultETy.isa() && rhsAttr.getSplatValue().isOne()) + if (llvm::isa(resultETy) && + rhsAttr.getSplatValue().isOne()) return getInput1(); } if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat()) { - if (resultETy.isa()) { + if (llvm::isa(resultETy)) { APInt l = lhsAttr.getSplatValue(); APInt r = rhsAttr.getSplatValue(); APInt result = l.sdiv(r); @@ -602,7 +603,7 @@ DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, RankedTensorType ty, int32_t shift) { if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) { - if (ty.getElementType().isa()) { + if (llvm::isa(ty.getElementType())) { APInt l = lhs.getSplatValue(); APInt r = rhs.getSplatValue(); @@ -619,7 +620,7 @@ return DenseElementsAttr::get(ty, result); } - if (ty.getElementType().isa()) { + if (llvm::isa(ty.getElementType())) { APFloat l = lhs.getSplatValue(); APFloat r = rhs.getSplatValue(); APFloat result = l * r; @@ -634,9 +635,9 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) { auto lhs = getInput1(); auto rhs = getInput2(); - auto lhsTy = lhs.getType().dyn_cast(); - auto rhsTy = rhs.getType().dyn_cast(); - auto resultTy = getType().dyn_cast(); + auto lhsTy = llvm::dyn_cast(lhs.getType()); + auto rhsTy = llvm::dyn_cast(rhs.getType()); + auto resultTy = llvm::dyn_cast(getType()); if (!lhsTy || !rhsTy || !resultTy) return {}; @@ -644,7 +645,7 @@ auto lhsAttr = adaptor.getInput1().dyn_cast_or_null(); auto rhsAttr = adaptor.getInput2().dyn_cast_or_null(); - const int64_t shift = resultETy.isa() ? getShift() : 0; + const int64_t shift = llvm::isa(resultETy) ? getShift() : 0; if (rhsTy == resultTy) { if (isSplatZero(resultETy, lhsAttr)) return lhsAttr; @@ -662,9 +663,9 @@ } OpFoldResult SubOp::fold(FoldAdaptor adaptor) { - auto lhsTy = getInput1().getType().dyn_cast(); - auto rhsTy = getInput2().getType().dyn_cast(); - auto resultTy = getType().dyn_cast(); + auto lhsTy = llvm::dyn_cast(getInput1().getType()); + auto rhsTy = llvm::dyn_cast(getInput2().getType()); + auto resultTy = llvm::dyn_cast(getType()); if (!lhsTy || !rhsTy || !resultTy) return {}; @@ -711,7 +712,7 @@ } // namespace OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) { - auto resultTy = getType().dyn_cast(); + auto resultTy = llvm::dyn_cast(getType()); auto lhsAttr = adaptor.getInput1().dyn_cast_or_null(); auto rhsAttr = adaptor.getInput2().dyn_cast_or_null(); @@ -723,7 +724,7 @@ } OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) { - auto resultTy = getType().dyn_cast(); + auto resultTy = llvm::dyn_cast(getType()); auto lhsAttr = adaptor.getInput1().dyn_cast_or_null(); auto rhsAttr = adaptor.getInput2().dyn_cast_or_null(); @@ -736,16 +737,16 @@ } OpFoldResult EqualOp::fold(FoldAdaptor adaptor) { - auto resultTy = getType().dyn_cast(); + auto resultTy = llvm::dyn_cast(getType()); auto lhsAttr = adaptor.getInput1().dyn_cast_or_null(); auto rhsAttr = adaptor.getInput2().dyn_cast_or_null(); Value lhs = getInput1(); Value rhs = getInput2(); - auto lhsTy = lhs.getType().cast(); + auto lhsTy = llvm::cast(lhs.getType()); // If we are comparing an integer value to itself it is always true. We can // not do this with float due to float values. - if (lhsTy.getElementType().isa() && resultTy && + if (llvm::isa(lhsTy.getElementType()) && resultTy && resultTy.hasStaticShape() && lhs == rhs) { return DenseElementsAttr::get(resultTy, true); } @@ -766,41 +767,41 @@ if (!operand) return {}; - auto inTy = getInput().getType().cast(); - auto outTy = getType().cast(); + auto inTy = llvm::cast(getInput().getType()); + auto outTy = llvm::cast(getType()); auto inETy = inTy.getElementType(); auto outETy = outTy.getElementType(); if (operand.isSplat()) { - if (inETy.isa() && outETy.isa()) { + if (llvm::isa(inETy) && llvm::isa(outETy)) { bool overflow; auto splatVal = operand.getSplatValue(); - auto &semantics = outETy.cast().getFloatSemantics(); + auto &semantics = llvm::cast(outETy).getFloatSemantics(); splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven, &overflow); return SplatElementsAttr::get(outTy, splatVal); } - if (inETy.isa() && outETy.isa()) { - auto unsign = inETy.cast().isUnsignedInteger(); - APFloat splatVal(outETy.cast().getFloatSemantics()); + if (llvm::isa(inETy) && llvm::isa(outETy)) { + auto unsign = llvm::cast(inETy).isUnsignedInteger(); + APFloat splatVal(llvm::cast(outETy).getFloatSemantics()); splatVal.convertFromAPInt(operand.getSplatValue(), !unsign, llvm::RoundingMode::NearestTiesToEven); return SplatElementsAttr::get(outTy, splatVal); } - if (inETy.isa() && outETy.isa()) { - auto unsign = outETy.cast().isUnsignedInteger(); - auto intVal = - APSInt(outETy.cast().getIntOrFloatBitWidth(), unsign); + if (llvm::isa(inETy) && llvm::isa(outETy)) { + auto unsign = llvm::cast(outETy).isUnsignedInteger(); + auto intVal = APSInt( + llvm::cast(outETy).getIntOrFloatBitWidth(), unsign); auto floatVal = operand.getSplatValue(); bool exact; floatVal.convertToInteger(intVal, llvm::RoundingMode::TowardZero, &exact); return SplatElementsAttr::get(outTy, intVal); } - if (inETy.isa() && outETy.isa()) { - auto unsignIn = inETy.cast().isUnsignedInteger(); + if (llvm::isa(inETy) && llvm::isa(outETy)) { + auto unsignIn = llvm::cast(inETy).isUnsignedInteger(); bool trunc = inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth(); auto intVal = operand.getSplatValue(); @@ -842,8 +843,8 @@ #undef REDUCE_FOLDER OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) { - auto inputTy = getInput1().getType().dyn_cast(); - auto outputTy = getType().dyn_cast(); + auto inputTy = llvm::dyn_cast(getInput1().getType()); + auto outputTy = llvm::dyn_cast(getType()); if (!inputTy || !outputTy) return {}; @@ -894,8 +895,8 @@ } auto input = getInput(); - auto inputTy = input.getType().cast(); - auto resultTy = getType().cast(); + auto inputTy = llvm::cast(input.getType()); + auto resultTy = llvm::cast(getType()); if (inputTy != resultTy) return {}; @@ -904,7 +905,7 @@ OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) { auto operand = getInput(); - auto operandTy = operand.getType().cast(); + auto operandTy = llvm::cast(operand.getType()); auto axis = getAxis(); auto operandAttr = adaptor.getInput().dyn_cast_or_null(); if (operandAttr) @@ -918,8 +919,8 @@ } OpFoldResult SliceOp::fold(FoldAdaptor adaptor) { - auto inputTy = getInput().getType().dyn_cast(); - auto outputTy = getType().dyn_cast(); + auto inputTy = llvm::dyn_cast(getInput().getType()); + auto outputTy = llvm::dyn_cast(getType()); if (!inputTy || !outputTy) return {}; @@ -972,8 +973,8 @@ } OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) { - auto inputTy = getInput1().getType().cast(); - auto resultTy = getType().cast(); + auto inputTy = llvm::cast(getInput1().getType()); + auto resultTy = llvm::cast(getType()); // Transposing splat values just means reshaping. if (auto input = adaptor.getInput1().dyn_cast_or_null()) { diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -90,8 +90,9 @@ Type type, Location loc) { // Tosa dialect constants only support ElementsAttr unlike standard dialect // constant which supports all attributes. - if (value.isa()) - return builder.create(loc, type, value.cast()); + if (llvm::isa(value)) + return builder.create(loc, type, + llvm::cast(value)); return nullptr; } @@ -101,10 +102,8 @@ template static LogicalResult verifyConvOp(T op) { // All TOSA conv ops have an input() and weight(). - auto inputType = - op.getInput().getType().template dyn_cast(); - auto weightType = - op.getWeight().getType().template dyn_cast(); + auto inputType = llvm::dyn_cast(op.getInput().getType()); + auto weightType = llvm::dyn_cast(op.getWeight().getType()); // Must be ranked tensor types if (!inputType) { @@ -119,8 +118,8 @@ auto inputEType = inputType.getElementType(); auto weightEType = weightType.getElementType(); - bool inputIsQuant = !inputEType.template isa(); - bool weightIsQuant = !weightEType.template isa(); + bool inputIsQuant = !llvm::isa(inputEType); + bool weightIsQuant = !llvm::isa(weightEType); // Either both must be quantized or both unquantized. if (inputIsQuant != weightIsQuant) { @@ -143,13 +142,15 @@ } LogicalResult tosa::AvgPool2dOp::verify() { - auto inputETy = getInput().getType().cast().getElementType(); - auto resultETy = getType().cast().getElementType(); + auto inputETy = llvm::cast(getInput().getType()).getElementType(); + auto resultETy = llvm::cast(getType()).getElementType(); - if (auto quantType = inputETy.dyn_cast()) + if (auto quantType = + llvm::dyn_cast(inputETy)) inputETy = quantType.getStorageType(); - if (auto quantType = resultETy.dyn_cast()) + if (auto quantType = + llvm::dyn_cast(resultETy)) resultETy = quantType.getStorageType(); if (inputETy.isF32() && resultETy.isF32()) @@ -240,16 +241,16 @@ if (quantAttr) { result.addAttribute("quantization_info", quantAttr); - auto inputType = a.getType().dyn_cast(); + auto inputType = llvm::dyn_cast(a.getType()); assert(inputType && "Input must be a shaped tensor type!"); - auto inputQType = inputType.getElementType() - .dyn_cast(); + auto inputQType = llvm::dyn_cast( + inputType.getElementType()); assert(inputQType && "Tensor must have quantized datatype!"); unsigned inputBits = inputQType.getStorageTypeIntegralWidth(); - auto outputShapedType = outputType.dyn_cast(); + auto outputShapedType = llvm::dyn_cast(outputType); assert(outputShapedType && "Output must be a shaped type"); IntegerType accElementType; @@ -368,7 +369,7 @@ OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { ShapeAdaptor inputShape = operands.getShape(0); - IntegerAttr axis = attributes.get("axis").cast(); + IntegerAttr axis = llvm::cast(attributes.get("axis")); int32_t axisVal = axis.getValue().getSExtValue(); if (!inputShape.hasRank()) { @@ -432,7 +433,7 @@ SmallVectorImpl &inferredReturnShapes) { // Infer all dimension sizes by reducing based on inputs. int32_t axis = - attributes.get("axis").cast().getValue().getSExtValue(); + llvm::cast(attributes.get("axis")).getValue().getSExtValue(); llvm::SmallVector outputShape; bool hasRankedInput = false; for (auto operand : operands) { @@ -459,7 +460,8 @@ hasRankedInput = true; } - Type inputType = operands.getType()[0].cast().getElementType(); + Type inputType = + llvm::cast(operands.getType()[0]).getElementType(); if (!hasRankedInput) { inferredReturnShapes.push_back(ShapedTypeComponents(inputType)); return success(); @@ -738,8 +740,8 @@ } mlir::LogicalResult tosa::ReshapeOp::verify() { - ShapedType inputType = getInput1().getType().cast(); - ShapedType outputType = getType().cast(); + ShapedType inputType = llvm::cast(getInput1().getType()); + ShapedType outputType = llvm::cast(getType()); if (inputType.hasStaticShape() && outputType.hasStaticShape()) { int64_t inputElementsNum = inputType.getNumElements(); @@ -1064,9 +1066,11 @@ int64_t height = inputShape.getDimSize(1); int64_t width = inputShape.getDimSize(2); - ArrayRef kernel = attributes.get("kernel").cast(); - ArrayRef stride = attributes.get("stride").cast(); - ArrayRef pad = attributes.get("pad").cast(); + ArrayRef kernel = + llvm::cast(attributes.get("kernel")); + ArrayRef stride = + llvm::cast(attributes.get("stride")); + ArrayRef pad = llvm::cast(attributes.get("pad")); if (!ShapedType::isDynamic(height)) { int64_t padded = height + pad[0] + pad[1] - kernel[0]; @@ -1473,7 +1477,7 @@ } std::optional> ApplyScaleOp::getShapeForUnroll() { - if (auto vt = getType().dyn_cast()) + if (auto vt = llvm::dyn_cast(getType())) return llvm::to_vector<4>(vt.getShape()); return std::nullopt; } diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp @@ -169,7 +169,7 @@ return success(); } if (attribute.getName().getValue() == kTargetTagAttrName) { - if (!attribute.getValue().isa()) { + if (!llvm::isa(attribute.getValue())) { return op->emitError() << attribute.getName() << " attribute must be a string"; } @@ -177,7 +177,7 @@ } if (attribute.getName().getValue() == kArgConsumedAttrName || attribute.getName().getValue() == kArgReadOnlyAttrName) { - if (!attribute.getValue().isa()) { + if (!llvm::isa(attribute.getValue())) { return op->emitError() << attribute.getName() << " must be a unit attribute"; } diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -114,7 +114,7 @@ function_ref)> operationsFn, function_ref)> paramsFn, function_ref valuesFn) { - if (handle.getType().isa()) { + if (llvm::isa(handle.getType())) { SmallVector operations; operations.reserve(values.size()); for (transform::MappedValue value : values) { @@ -130,7 +130,8 @@ return DiagnosedSilenceableFailure::success(); } - if (handle.getType().isa()) { + if (llvm::isa( + handle.getType())) { SmallVector payloadValues; payloadValues.reserve(values.size()); for (transform::MappedValue value : values) { @@ -146,7 +147,7 @@ return DiagnosedSilenceableFailure::success(); } - assert(handle.getType().isa() && + assert(llvm::isa(handle.getType()) && "unsupported kind of block argument"); SmallVector parameters; parameters.reserve(values.size()); @@ -185,7 +186,7 @@ ArrayRef targets) { assert(value != kTopLevelValue && "attempting to reset the transformation root"); - assert(value.getType().isa() && + assert(llvm::isa(value.getType()) && "wrong handle type"); for (Operation *target : targets) { @@ -195,7 +196,7 @@ << "attempting to assign a null payload op to this transform value"; } - auto iface = value.getType().cast(); + auto iface = llvm::cast(value.getType()); DiagnosedSilenceableFailure result = iface.checkPayload(value.getLoc(), targets); if (failed(result.checkAndReport())) @@ -220,7 +221,7 @@ transform::TransformState::setPayloadValues(Value handle, ValueRange payloadValues) { assert(handle != nullptr && "attempting to set params for a null value"); - assert(handle.getType().isa() && + assert(llvm::isa(handle.getType()) && "wrong handle type"); for (Value payload : payloadValues) { @@ -230,7 +231,7 @@ "value to this transform handle"; } - auto iface = handle.getType().cast(); + auto iface = llvm::cast(handle.getType()); SmallVector payloadValueVector = llvm::to_vector(payloadValues); DiagnosedSilenceableFailure result = iface.checkPayload(handle.getLoc(), payloadValueVector); @@ -262,7 +263,7 @@ << "attempting to assign a null parameter to this transform value"; } - auto valueType = value.getType().dyn_cast(); + auto valueType = llvm::dyn_cast(value.getType()); assert(value && "cannot associate parameter with a value of non-parameter type"); DiagnosedSilenceableFailure result = @@ -497,11 +498,11 @@ Operation *definingOp; std::optional resultNo; unsigned argumentNo, blockNo, regionNo; - if (auto opResult = payloadValue.dyn_cast()) { + if (auto opResult = llvm::dyn_cast(payloadValue)) { definingOp = opResult.getOwner(); resultNo = opResult.getResultNumber(); } else { - auto arg = payloadValue.cast(); + auto arg = llvm::cast(payloadValue); definingOp = arg.getParentBlock()->getParentOp(); argumentNo = arg.getArgNumber(); blockNo = std::distance(arg.getOwner()->getParent()->begin(), @@ -602,11 +603,11 @@ }; } - if (auto opResult = payloadValue.dyn_cast()) { + if (auto opResult = llvm::dyn_cast(payloadValue)) { Operation *payloadOp = opResult.getOwner(); recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue); } else { - auto arg = payloadValue.dyn_cast(); + auto arg = llvm::dyn_cast(payloadValue); for (Operation &payloadOp : *arg.getOwner()) recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue); } @@ -642,13 +643,12 @@ }; if (llvm::any_of(effects, consumesTarget)) { FULL_LDBG("----found consume effect -> SKIP\n"); - if (target.get().getType().isa()) { + if (llvm::isa(target.get().getType())) { FULL_LDBG("----recordOpHandleInvalidation\n"); ArrayRef payloadOps = getPayloadOps(target.get()); recordOpHandleInvalidation(target, payloadOps); - } else if (target.get() - .getType() - .isa()) { + } else if (llvm::isa( + target.get().getType())) { FULL_LDBG("----recordValueHandleInvalidation\n"); recordValueHandleInvalidation(target); } else { @@ -717,7 +717,7 @@ FULL_LDBG("--handle is consumed\n"); Type operandType = operand.get().getType(); - if (operandType.isa()) { + if (llvm::isa(operandType)) { FULL_LDBG("--checkRepeatedConsumptionInOperand for Operation*\n"); DiagnosedSilenceableFailure check = checkRepeatedConsumptionInOperand( @@ -727,7 +727,7 @@ FULL_LDBG("----FAILED\n"); return check; } - } else if (operandType.isa()) { + } else if (llvm::isa(operandType)) { FULL_LDBG("--checkRepeatedConsumptionInOperand For Value\n"); DiagnosedSilenceableFailure check = checkRepeatedConsumptionInOperand( @@ -794,7 +794,7 @@ #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS for (unsigned index : consumedOperands) { Value operand = transform->getOperand(index); - if (operand.getType().isa()) { + if (llvm::isa(operand.getType())) { for (Operation *payloadOp : getPayloadOps(operand)) { llvm::append_range(origOpFlatResults, payloadOp->getResults()); #if LLVM_ENABLE_ABI_BREAKING_CHECKS @@ -808,15 +808,15 @@ } continue; } - if (operand.getType().isa()) { + if (llvm::isa(operand.getType())) { for (Value payloadValue : getPayloadValues(operand)) { - if (payloadValue.isa()) { + if (llvm::isa(payloadValue)) { origAssociatedOps.push_back(payloadValue.getDefiningOp()); continue; } llvm::append_range( origAssociatedOps, - llvm::map_range(*payloadValue.cast().getOwner(), + llvm::map_range(*llvm::cast(payloadValue).getOwner(), [](Operation &op) { return &op; })); } continue; @@ -847,9 +847,10 @@ // allows us to catch use-after-free with assertions later on. for (unsigned index : consumedOperands) { Value operand = transform->getOperand(index); - if (operand.getType().isa()) { + if (llvm::isa(operand.getType())) { forgetMapping(operand, origOpFlatResults); - } else if (operand.getType().isa()) { + } else if (llvm::isa( + operand.getType())) { forgetValueMapping(operand, origAssociatedOps); } } @@ -923,14 +924,14 @@ LogicalResult transform::TransformState::updateStateFromResults( const TransformResults &results, ResultRange opResults) { for (OpResult result : opResults) { - if (result.getType().isa()) { + if (llvm::isa(result.getType())) { assert(results.isParam(result.getResultNumber()) && "expected parameters for the parameter-typed result"); if (failed( setParams(result, results.getParams(result.getResultNumber())))) { return failure(); } - } else if (result.getType().isa()) { + } else if (llvm::isa(result.getType())) { assert(results.isValue(result.getResultNumber()) && "expected values for value-type-result"); if (failed(setPayloadValues( @@ -1137,19 +1138,19 @@ llvm::zip(partialResult, transformOp->getResults())) { if (ptr.isNull()) continue; - if (res.getType().template isa() && + if (llvm::isa(res.getType()) && !ptr.is()) { return emitDiag() << "application of " << transformOpName << " expected to produce an Operation * for result #" << res.getResultNumber(); } - if (res.getType().template isa() && + if (llvm::isa(res.getType()) && !ptr.is()) { return emitDiag() << "application of " << transformOpName << " expected to produce an Attribute for result #" << res.getResultNumber(); } - if (res.getType().template isa() && + if (llvm::isa(res.getType()) && !ptr.is()) { return emitDiag() << "application of " << transformOpName << " expected to produce a Value for result #" @@ -1182,10 +1183,10 @@ for (OpResult r : transformOp->getResults()) { unsigned position = r.getResultNumber(); - if (r.getType().isa()) { + if (llvm::isa(r.getType())) { transformResults.setParams(r, castVector(transposed[position])); - } else if (r.getType().isa()) { + } else if (llvm::isa(r.getType())) { transformResults.setValues(r, castVector(transposed[position])); } else { transformResults.set(r, castVector(transposed[position])); @@ -1202,12 +1203,13 @@ ValueRange values, const transform::TransformState &state) { for (Value operand : values) { SmallVector &mapped = mappings.emplace_back(); - if (operand.getType().isa()) { + if (llvm::isa(operand.getType())) { llvm::append_range(mapped, state.getPayloadOps(operand)); - } else if (operand.getType().isa()) { + } else if (llvm::isa( + operand.getType())) { llvm::append_range(mapped, state.getPayloadValues(operand)); } else { - assert(operand.getType().isa() && + assert(llvm::isa(operand.getType()) && "unsupported kind of transform dialect value"); llvm::append_range(mapped, state.getParams(operand)); } @@ -1220,14 +1222,15 @@ for (auto &&[terminatorOperand, result] : llvm::zip(block->getTerminator()->getOperands(), block->getParentOp()->getOpResults())) { - if (result.getType().isa()) { + if (llvm::isa(result.getType())) { results.set(result, state.getPayloadOps(terminatorOperand)); - } else if (result.getType() - .isa()) { + } else if (llvm::isa( + result.getType())) { results.setValues(result, state.getPayloadValues(terminatorOperand)); } else { - assert(result.getType().isa() && - "unhandled transform type interface"); + assert( + llvm::isa(result.getType()) && + "unhandled transform type interface"); results.setParams(result, state.getParams(terminatorOperand)); } } @@ -1291,7 +1294,8 @@ return op->emitOpError() << "expects the entry block to have at least one argument"; } - if (!body->getArgument(0).getType().isa()) { + if (!llvm::isa( + body->getArgument(0).getType())) { return op->emitOpError() << "expects the first entry block argument to be of type " "implementing TransformHandleTypeInterface"; @@ -1305,9 +1309,8 @@ } } for (BlockArgument arg : body->getArguments().drop_front()) { - if (arg.getType() - .isa()) + if (llvm::isa(arg.getType())) continue; InFlightDiagnostic diag = @@ -1344,9 +1347,8 @@ bool hasPayloadOperands = false; for (Value operand : op->getOperands()) { onlyReadsHandle(operand, effects); - if (operand.getType() - .isa()) + if (llvm::isa(operand.getType())) hasPayloadOperands = true; } if (hasPayloadOperands) @@ -1364,7 +1366,7 @@ op->getName().getStringRef()); } for (Value result : op->getResults()) { - if (result.getType().isa()) + if (llvm::isa(result.getType())) continue; return op->emitOpError() << "ParamProducerTransformOpTrait attached to this op expects " diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -440,8 +440,8 @@ return llvm::all_of( std::initializer_list{inputs.front(), outputs.front()}, [](Type ty) { - return ty - .isa(); + return llvm::isa(ty); }); } @@ -563,7 +563,8 @@ // the payload to the result. Note that we need to consume the root handle to // make sure any handles to operations inside, that could have been affected // by actions, are invalidated. - results.set(getUpdated().cast(), state.getPayloadOps(getRoot())); + results.set(llvm::cast(getUpdated()), + state.getPayloadOps(getRoot())); return DiagnosedSilenceableFailure::success(); } @@ -810,7 +811,7 @@ } for (unsigned i = 0; i < getNumResults(); ++i) - results.set(getResult(i).cast(), resultOps[i]); + results.set(llvm::cast(getResult(i)), resultOps[i]); return DiagnosedSilenceableFailure::success(); } @@ -863,7 +864,7 @@ return emitOpError() << "expects the same number of results as the " "terminator has operands"; for (Value v : yieldOp.getOperands()) - if (!v.getType().isa()) + if (!llvm::isa(v.getType())) return yieldOp->emitOpError("expects operands to have types implementing " "TransformHandleTypeInterface"); return success(); @@ -888,7 +889,7 @@ } parents.insert(parent); } - results.set(getResult().cast(), parents.getArrayRef()); + results.set(llvm::cast(getResult()), parents.getArrayRef()); return DiagnosedSilenceableFailure::success(); } @@ -902,7 +903,7 @@ int64_t resultNumber = getResultNumber(); ArrayRef payloadOps = state.getPayloadOps(getTarget()); if (payloadOps.empty()) { - results.set(getResult().cast(), {}); + results.set(llvm::cast(getResult()), {}); return DiagnosedSilenceableFailure::success(); } if (payloadOps.size() != 1) @@ -912,7 +913,7 @@ Operation *target = payloadOps.front(); if (target->getNumResults() <= resultNumber) return emitDefiniteFailure() << "result number overflow"; - results.set(getResult().cast(), + results.set(llvm::cast(getResult()), llvm::to_vector(target->getResult(resultNumber).getUsers())); return DiagnosedSilenceableFailure::success(); } @@ -926,7 +927,7 @@ transform::TransformState &state) { SmallVector definingOps; for (Value v : state.getPayloadValues(getTarget())) { - if (v.isa()) { + if (llvm::isa(v)) { DiagnosedSilenceableFailure diag = emitSilenceableError() << "cannot get defining op of block argument"; diag.attachNote(v.getLoc()) << "target value"; @@ -934,7 +935,7 @@ } definingOps.push_back(v.getDefiningOp()); } - results.set(getResult().cast(), definingOps); + results.set(llvm::cast(getResult()), definingOps); return DiagnosedSilenceableFailure::success(); } @@ -962,7 +963,7 @@ } producers.push_back(producer); } - results.set(getResult().cast(), producers); + results.set(llvm::cast(getResult()), producers); return DiagnosedSilenceableFailure::success(); } @@ -984,7 +985,7 @@ } opResults.push_back(target->getOpResult(resultNumber)); } - results.setValues(getResult().cast(), opResults); + results.setValues(llvm::cast(getResult()), opResults); return DiagnosedSilenceableFailure::success(); } @@ -1211,8 +1212,8 @@ } for (auto &&[i, param, reference] : llvm::enumerate(params, references)) { - auto intAttr = param.dyn_cast(); - auto refAttr = reference.dyn_cast(); + auto intAttr = llvm::dyn_cast(param); + auto refAttr = llvm::dyn_cast(reference); if (!intAttr || !refAttr) { return emitDefiniteFailure() << "non-integer parameter value not expected"; @@ -1295,12 +1296,12 @@ for (Value operand : getHandles()) llvm::append_range(operations, state.getPayloadOps(operand)); if (!getDeduplicate()) { - results.set(getResult().cast(), operations); + results.set(llvm::cast(getResult()), operations); return DiagnosedSilenceableFailure::success(); } SetVector uniqued(operations.begin(), operations.end()); - results.set(getResult().cast(), uniqued.getArrayRef()); + results.set(llvm::cast(getResult()), uniqued.getArrayRef()); return DiagnosedSilenceableFailure::success(); } @@ -1535,7 +1536,7 @@ // Set transform op results. for (auto &&it : llvm::enumerate(resultHandles)) - results.set(getResult(it.index()).cast(), it.value()); + results.set(llvm::cast(getResult(it.index())), it.value()); return DiagnosedSilenceableFailure::success(); } @@ -1573,7 +1574,7 @@ << "could not find pattern '" << getPatternName() << "'"; } } - results.set(getResult().cast(), targets); + results.set(llvm::cast(getResult()), targets); return DiagnosedSilenceableFailure::success(); } @@ -1594,22 +1595,23 @@ unsigned numRepetitions = state.getPayloadOps(getPattern()).size(); for (const auto &en : llvm::enumerate(getHandles())) { Value handle = en.value(); - if (handle.getType().isa()) { + if (llvm::isa(handle.getType())) { ArrayRef current = state.getPayloadOps(handle); SmallVector payload; payload.reserve(numRepetitions * current.size()); for (unsigned i = 0; i < numRepetitions; ++i) llvm::append_range(payload, current); - results.set(getReplicated()[en.index()].cast(), payload); + results.set(llvm::cast(getReplicated()[en.index()]), payload); } else { - assert(handle.getType().isa() && + assert(llvm::isa(handle.getType()) && "expected param type"); ArrayRef current = state.getParams(handle); SmallVector params; params.reserve(numRepetitions * current.size()); for (unsigned i = 0; i < numRepetitions; ++i) llvm::append_range(params, current); - results.setParams(getReplicated()[en.index()].cast(), params); + results.setParams(llvm::cast(getReplicated()[en.index()]), + params); } } return DiagnosedSilenceableFailure::success(); diff --git a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp @@ -75,7 +75,7 @@ LogicalResult transform::ParamType::verify(function_ref emitError, Type type) { - IntegerType intType = type.dyn_cast(); + IntegerType intType = llvm::dyn_cast(type); if (!intType || intType.getWidth() > 64) return emitError() << "only supports integer types with width <=64"; return success(); @@ -85,7 +85,7 @@ transform::ParamType::checkPayload(Location loc, ArrayRef payload) const { for (Attribute attr : payload) { - auto integerAttr = attr.dyn_cast(); + auto integerAttr = llvm::dyn_cast(attr); if (!integerAttr) { return emitSilenceableError(loc) << "expected parameter to be an integer attribute, got " << attr; diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -65,7 +65,7 @@ // Inspect constant dense values. We count up for bits that // are set, count down for bits that are cleared, and bail // when a mix is detected. - if (auto denseElts = c.getValue().dyn_cast()) { + if (auto denseElts = llvm::dyn_cast(c.getValue())) { int64_t val = 0; for (bool b : denseElts.getValues()) if (b && val >= 0) @@ -88,7 +88,7 @@ bool allTrue = true; bool allFalse = true; for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) { - int64_t i = maskIdx.cast().getInt(); + int64_t i = llvm::cast(maskIdx).getInt(); if (i < dimSize) allTrue = false; if (i > 0) @@ -125,7 +125,7 @@ return elementType.isIntOrIndex(); case CombiningKind::MINF: case CombiningKind::MAXF: - return elementType.isa(); + return llvm::isa(elementType); } return false; } @@ -143,7 +143,7 @@ VectorType vectorType) { int64_t elementVectorRank = 0; VectorType elementVectorType = - shapedType.getElementType().dyn_cast(); + llvm::dyn_cast(shapedType.getElementType()); if (elementVectorType) elementVectorRank += elementVectorType.getRank(); // 0-d transfers are to/from tensor/memref and vector<1xt>. @@ -190,15 +190,15 @@ if (i < rankOffset) { // For leading dimensions, if we can prove that index are different we // know we are accessing disjoint slices. - if (indexA.getValue().cast().getInt() != - indexB.getValue().cast().getInt()) + if (llvm::cast(indexA.getValue()).getInt() != + llvm::cast(indexB.getValue()).getInt()) return true; } else { // For this dimension, we slice a part of the memref we need to make sure // the intervals accessed don't overlap. int64_t distance = - std::abs(indexA.getValue().cast().getInt() - - indexB.getValue().cast().getInt()); + std::abs(llvm::cast(indexA.getValue()).getInt() - + llvm::cast(indexB.getValue()).getInt()); if (distance >= transferA.getVectorType().getDimSize(i - rankOffset)) return true; } @@ -325,7 +325,7 @@ Type inferredReturnType; for (auto it : llvm::enumerate(getSourceVectorType().getShape())) if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) { - return attr.cast().getValue() == it.index(); + return llvm::cast(attr).getValue() == it.index(); })) targetShape.push_back(it.value()); // TODO: update to also allow 0-d vectors when available. @@ -426,8 +426,9 @@ void vector::ReductionOp::build(OpBuilder &builder, OperationState &result, CombiningKind kind, Value vector, Value acc) { - build(builder, result, vector.getType().cast().getElementType(), - kind, vector, acc); + build(builder, result, + llvm::cast(vector.getType()).getElementType(), kind, vector, + acc); } LogicalResult ReductionOp::verify() { @@ -659,9 +660,8 @@ // because tests still use the old format when 'iterator_types' attribute is // represented as an array of strings. // TODO: Remove this conversion once tests are fixed. - ArrayAttr iteratorTypes = - result.attributes.get(getIteratorTypesAttrName(result.name)) - .cast(); + ArrayAttr iteratorTypes = llvm::cast( + result.attributes.get(getIteratorTypesAttrName(result.name))); SmallVector iteratorTypeAttrs; @@ -687,8 +687,8 @@ if (masksInfo.size() != 2) return parser.emitError(parser.getNameLoc(), "expected zero or exactly 2 vector mask operands"); - auto lhsType = types[0].cast(); - auto rhsType = types[1].cast(); + auto lhsType = llvm::cast(types[0]); + auto rhsType = llvm::cast(types[1]); auto maskElementType = parser.getBuilder().getI1Type(); std::array maskTypes = { VectorType::Builder(lhsType).setElementType(maskElementType), @@ -707,8 +707,7 @@ for (auto attr : (*this)->getAttrs()) { if (attr.getName() == getIteratorTypesAttrName()) { auto iteratorTypes = - attr.getValue() - .cast() + llvm::cast(attr.getValue()) .getAsValueRange(); // Convert IteratorType enums into the string representation. This is // needed, because tests still use the old format when 'iterator_types' @@ -778,12 +777,12 @@ // Verify 'expectedResultDims'. if (expectedResultDims.empty()) { // No batch or free dimension implies a scalar result. - if (resType.isa() || accType.isa()) + if (llvm::isa(resType) || llvm::isa(accType)) return op.emitOpError("invalid accumulator/result vector shape"); } else { // At least one batch or free dimension implies a vector result. - auto resVectorType = resType.dyn_cast(); - auto accVectorType = accType.dyn_cast(); + auto resVectorType = llvm::dyn_cast(resType); + auto accVectorType = llvm::dyn_cast(accType); if (!resVectorType || !accVectorType) return op.emitOpError("invalid accumulator/result vector shape"); @@ -841,7 +840,7 @@ Type accType = getAccType(); Type resType = getResultType(); - if (lhsType.getElementType().isa()) { + if (llvm::isa(lhsType.getElementType())) { if (!lhsType.getElementType().isSignlessInteger()) return emitOpError("only supports signless integer types"); } @@ -860,7 +859,7 @@ if (map.getNumSymbols() != 0) return emitOpError("expected indexing map ") << index << " to have no symbols"; - auto vectorType = getOperand(index).getType().dyn_cast(); + auto vectorType = llvm::dyn_cast(getOperand(index).getType()); unsigned rank = vectorType ? vectorType.getShape().size() : 0; // Verify that the map has the right number of inputs, outputs, and indices. // This also correctly accounts for (..) -> () for rank-0 results. @@ -896,7 +895,7 @@ return failure(); // Verify supported combining kind. - auto vectorType = resType.dyn_cast(); + auto vectorType = llvm::dyn_cast(resType); auto elementType = vectorType ? vectorType.getElementType() : resType; if (!isSupportedCombiningKind(getKind(), elementType)) return emitOpError("unsupported contraction type"); @@ -949,7 +948,7 @@ IteratorType targetIteratorType, MLIRContext *context) { std::vector> dimMap; for (const auto &it : llvm::enumerate(iteratorTypes)) { - auto iteratorType = it.value().cast().getValue(); + auto iteratorType = llvm::cast(it.value()).getValue(); if (iteratorType != targetIteratorType) continue; // Search lhs/rhs map results for 'targetExpr'. @@ -965,13 +964,13 @@ void ContractionOp::getIterationBounds( SmallVectorImpl &iterationBounds) { auto lhsShape = getLhsType().getShape(); - auto resVectorType = getResultType().dyn_cast(); + auto resVectorType = llvm::dyn_cast(getResultType()); SmallVector indexingMaps(getIndexingMapsArray()); SmallVector iterationShape; for (const auto &it : llvm::enumerate(getIteratorTypes())) { // Search lhs/rhs map results for 'targetExpr'. auto targetExpr = getAffineDimExpr(it.index(), getContext()); - auto iteratorType = it.value().cast().getValue(); + auto iteratorType = llvm::cast(it.value()).getValue(); if (iteratorType == IteratorType::reduction) { // Get reduction dim size from lhs shape (same size in rhsShape). int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr); @@ -1085,7 +1084,7 @@ void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result, Value source) { result.addOperands({source}); - result.addTypes(source.getType().cast().getElementType()); + result.addTypes(llvm::cast(source.getType()).getElementType()); } LogicalResult vector::ExtractElementOp::verify() { @@ -1116,15 +1115,15 @@ // Fold extractelement(broadcast(X)) -> X. if (auto broadcast = getVector().getDefiningOp()) - if (!broadcast.getSource().getType().isa()) + if (!llvm::isa(broadcast.getSource().getType())) return broadcast.getSource(); if (!pos || !src) return {}; - auto srcElements = src.cast().getValues(); + auto srcElements = llvm::cast(src).getValues(); - auto attr = pos.dyn_cast(); + auto attr = llvm::dyn_cast(pos); uint64_t posIdx = attr.getInt(); return srcElements[posIdx]; @@ -1155,7 +1154,7 @@ OpaqueProperties properties, RegionRange, SmallVectorImpl &inferredReturnTypes) { ExtractOp::Adaptor op(operands, attributes); - auto vectorType = op.getVector().getType().cast(); + auto vectorType = llvm::cast(op.getVector().getType()); if (static_cast(op.getPosition().size()) == vectorType.getRank()) { inferredReturnTypes.push_back(vectorType.getElementType()); } else { @@ -1170,7 +1169,7 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { // Allow extracting 1-element vectors instead of scalars. auto isCompatible = [](TypeRange l, TypeRange r) { - auto vectorType = l.front().dyn_cast(); + auto vectorType = llvm::dyn_cast(l.front()); return vectorType && vectorType.getShape().equals({1}) && vectorType.getElementType() == r.front(); }; @@ -1187,7 +1186,7 @@ return emitOpError( "expected position attribute of rank smaller than vector rank"); for (const auto &en : llvm::enumerate(positionAttr)) { - auto attr = en.value().dyn_cast(); + auto attr = llvm::dyn_cast(en.value()); if (!attr || attr.getInt() < 0 || attr.getInt() >= getSourceVectorType().getDimSize(en.index())) return emitOpError("expected position attribute #") @@ -1451,7 +1450,8 @@ if (extractOp.getType() == source.getType()) return source; auto getRank = [](Type type) { - return type.isa() ? type.cast().getRank() : 0; + return llvm::isa(type) ? llvm::cast(type).getRank() + : 0; }; // If splat or broadcast from a scalar, just return the source scalar. unsigned broadcastSrcRank = getRank(source.getType()); @@ -1462,8 +1462,8 @@ if (extractResultRank >= broadcastSrcRank) return Value(); // Check that the dimension of the result haven't been broadcasted. - auto extractVecType = extractOp.getType().dyn_cast(); - auto broadcastVecType = source.getType().dyn_cast(); + auto extractVecType = llvm::dyn_cast(extractOp.getType()); + auto broadcastVecType = llvm::dyn_cast(source.getType()); if (extractVecType && broadcastVecType && extractVecType.getShape() != broadcastVecType.getShape().take_back(extractResultRank)) @@ -1502,13 +1502,14 @@ return type.getShape().take_back(n + 1).front(); }; int64_t destinationRank = - extractOp.getType().isa() - ? extractOp.getType().cast().getRank() + llvm::isa(extractOp.getType()) + ? llvm::cast(extractOp.getType()).getRank() : 0; if (destinationRank > shapeCastOp.getSourceVectorType().getRank()) return Value(); if (destinationRank > 0) { - auto destinationType = extractOp.getResult().getType().cast(); + auto destinationType = + llvm::cast(extractOp.getResult().getType()); for (int64_t i = 0; i < destinationRank; i++) { // The lowest dimension of of the destination must match the lowest // dimension of the shapecast op source. @@ -1574,7 +1575,7 @@ sliceOffsets.pop_back(); } unsigned destinationRank = 0; - if (auto vecType = extractOp.getType().dyn_cast()) + if (auto vecType = llvm::dyn_cast(extractOp.getType())) destinationRank = vecType.getRank(); // The dimensions of the result need to be untouched by the // extractStridedSlice op. @@ -1595,8 +1596,8 @@ /// Fold extract_op fed from a chain of insertStridedSlice ops. static Value foldExtractStridedOpFromInsertChain(ExtractOp op) { - int64_t destinationRank = op.getType().isa() - ? op.getType().cast().getRank() + int64_t destinationRank = llvm::isa(op.getType()) + ? llvm::cast(op.getType()).getRank() : 0; auto insertOp = op.getVector().getDefiningOp(); while (insertOp) { @@ -1608,7 +1609,7 @@ auto extractOffsets = extractVector(op.getPosition()); if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) { - return attr.cast().getInt() != 1; + return llvm::cast(attr).getInt() != 1; })) return Value(); bool disjoint = false; @@ -1691,7 +1692,9 @@ if (extractOp.getType() == source.getType()) return failure(); auto getRank = [](Type type) { - return type.isa() ? type.cast().getRank() : 0; + return llvm::isa(type) + ? llvm::cast(type).getRank() + : 0; }; unsigned broadcastSrcRank = getRank(source.getType()); unsigned extractResultRank = getRank(extractOp.getType()); @@ -1703,7 +1706,7 @@ // Special case if broadcast src is a 0D vector. if (extractResultRank == 0) { - assert(broadcastSrcRank == 0 && source.getType().isa()); + assert(broadcastSrcRank == 0 && llvm::isa(source.getType())); rewriter.replaceOpWithNewOp(extractOp, source); return success(); } @@ -1726,11 +1729,11 @@ Attribute vectorCst; if (!matchPattern(sourceVector, m_Constant(&vectorCst))) return failure(); - auto splat = vectorCst.dyn_cast(); + auto splat = llvm::dyn_cast(vectorCst); if (!splat) return failure(); TypedAttr newAttr = splat.getSplatValue(); - if (auto vecDstType = extractOp.getType().dyn_cast()) + if (auto vecDstType = llvm::dyn_cast(extractOp.getType())) newAttr = DenseElementsAttr::get(vecDstType, newAttr); rewriter.replaceOpWithNewOp(extractOp, newAttr); return success(); @@ -1752,12 +1755,12 @@ if (!matchPattern(sourceVector, m_Constant(&vectorCst))) return failure(); - auto vecTy = sourceVector.getType().cast(); + auto vecTy = llvm::cast(sourceVector.getType()); if (vecTy.isScalable()) return failure(); // The splat case is handled by `ExtractOpSplatConstantFolder`. - auto dense = vectorCst.dyn_cast(); + auto dense = llvm::dyn_cast(vectorCst); if (!dense || dense.isSplat()) return failure(); @@ -1770,7 +1773,7 @@ auto denseValuesBegin = dense.value_begin() + elemBeginPosition; TypedAttr newAttr; - if (auto resVecTy = extractOp.getType().dyn_cast()) { + if (auto resVecTy = llvm::dyn_cast(extractOp.getType())) { SmallVector elementValues( denseValuesBegin, denseValuesBegin + resVecTy.getNumElements()); newAttr = DenseElementsAttr::get(resVecTy, elementValues); @@ -1794,7 +1797,7 @@ static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl &results) { for (auto attr : arrayAttr) - results.push_back(attr.cast().getInt()); + results.push_back(llvm::cast(attr).getInt()); } //===----------------------------------------------------------------------===// @@ -1830,7 +1833,7 @@ llvm::SetVector BroadcastOp::computeBroadcastedUnitDims() { // Scalar broadcast is without any unit dim broadcast. - auto srcVectorType = getSourceType().dyn_cast(); + auto srcVectorType = llvm::dyn_cast(getSourceType()); if (!srcVectorType) return {}; return ::computeBroadcastedUnitDims(srcVectorType.getShape(), @@ -1867,7 +1870,7 @@ Location loc = value.getLoc(); Type elementType = getElementTypeOrSelf(value.getType()); - VectorType srcVectorType = value.getType().dyn_cast(); + VectorType srcVectorType = llvm::dyn_cast(value.getType()); VectorType dstVectorType = VectorType::get(dstShape, elementType); // Step 2. If scalar -> dstShape broadcast, just do it. @@ -1952,7 +1955,7 @@ getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType)) return BroadcastableToResult::Success; // From now on, only vectors broadcast. - VectorType srcVectorType = srcType.dyn_cast(); + VectorType srcVectorType = llvm::dyn_cast(srcType); if (!srcVectorType) return BroadcastableToResult::SourceTypeNotAVector; @@ -2074,7 +2077,7 @@ int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) + (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0)); for (const auto &en : llvm::enumerate(maskAttr)) { - auto attr = en.value().dyn_cast(); + auto attr = llvm::dyn_cast(en.value()); if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize) return emitOpError("mask index #") << (en.index() + 1) << " out of range"; } @@ -2087,7 +2090,7 @@ OpaqueProperties properties, RegionRange, SmallVectorImpl &inferredReturnTypes) { ShuffleOp::Adaptor op(operands, attributes); - auto v1Type = op.getV1().getType().cast(); + auto v1Type = llvm::cast(op.getV1().getType()); auto v1Rank = v1Type.getRank(); // Construct resulting type: leading dimension matches mask // length, all trailing dimensions match the operands. @@ -2132,7 +2135,8 @@ if (!lhs || !rhs) return {}; - auto lhsType = lhs.cast().getType().cast(); + auto lhsType = + llvm::cast(llvm::cast(lhs).getType()); // Only support 1-D for now to avoid complicated n-D DenseElementsAttr // manipulation. if (lhsType.getRank() != 1) @@ -2140,8 +2144,8 @@ int64_t lhsSize = lhsType.getDimSize(0); SmallVector results; - auto lhsElements = lhs.cast().getValues(); - auto rhsElements = rhs.cast().getValues(); + auto lhsElements = llvm::cast(lhs).getValues(); + auto rhsElements = llvm::cast(rhs).getValues(); for (const auto &index : this->getMask().getAsValueRange()) { int64_t i = index.getZExtValue(); if (i >= lhsSize) { @@ -2170,7 +2174,7 @@ if (mask.size() != 1) return failure(); Type resType = VectorType::Builder(v1VectorType).setShape({1}); - if (mask[0].cast().getInt() == 0) + if (llvm::cast(mask[0]).getInt() == 0) rewriter.replaceOpWithNewOp(shuffleOp, resType, shuffleOp.getV1()); else @@ -2242,11 +2246,11 @@ if (!src || !dst || !pos) return {}; - auto dstElements = dst.cast().getValues(); + auto dstElements = llvm::cast(dst).getValues(); SmallVector results(dstElements); - auto attr = pos.dyn_cast(); + auto attr = llvm::dyn_cast(pos); uint64_t posIdx = attr.getInt(); results[posIdx] = src; @@ -2282,7 +2286,7 @@ if (positionAttr.size() > static_cast(destVectorType.getRank())) return emitOpError( "expected position attribute of rank smaller than dest vector rank"); - auto srcVectorType = getSourceType().dyn_cast(); + auto srcVectorType = llvm::dyn_cast(getSourceType()); if (srcVectorType && (static_cast(srcVectorType.getRank()) + positionAttr.size() != static_cast(destVectorType.getRank()))) @@ -2293,7 +2297,7 @@ return emitOpError( "expected position attribute rank to match the dest vector rank"); for (const auto &en : llvm::enumerate(positionAttr)) { - auto attr = en.value().dyn_cast(); + auto attr = llvm::dyn_cast(en.value()); if (!attr || attr.getInt() < 0 || attr.getInt() >= destVectorType.getDimSize(en.index())) return emitOpError("expected position attribute #") @@ -2314,7 +2318,7 @@ LogicalResult matchAndRewrite(InsertOp insertOp, PatternRewriter &rewriter) const override { - auto srcVecType = insertOp.getSourceType().dyn_cast(); + auto srcVecType = llvm::dyn_cast(insertOp.getSourceType()); if (!srcVecType || insertOp.getDestVectorType().getNumElements() != srcVecType.getNumElements()) return failure(); @@ -2372,7 +2376,7 @@ !destVector.hasOneUse()) return failure(); - auto denseDest = vectorDestCst.cast(); + auto denseDest = llvm::cast(vectorDestCst); Value sourceValue = op.getSource(); Attribute sourceCst; @@ -2387,7 +2391,7 @@ linearize(completePositions, computeStrides(destTy.getShape())); SmallVector insertedValues; - if (auto denseSource = sourceCst.dyn_cast()) + if (auto denseSource = llvm::dyn_cast(sourceCst)) llvm::append_range(insertedValues, denseSource.getValues()); else insertedValues.push_back(sourceCst); @@ -2455,7 +2459,7 @@ int64_t max, StringRef attrName, bool halfOpen = true) { for (auto attr : arrayAttr) { - auto val = attr.cast().getInt(); + auto val = llvm::cast(attr).getInt(); auto upper = max; if (!halfOpen) upper += 1; @@ -2476,8 +2480,7 @@ bool halfOpen = true, int64_t min = 0) { for (auto [index, attrDimPair] : llvm::enumerate(llvm::zip_first(arrayAttr, shape))) { - int64_t val = - std::get<0>(attrDimPair).template cast().getInt(); + int64_t val = llvm::cast(std::get<0>(attrDimPair)).getInt(); int64_t max = std::get<1>(attrDimPair); if (!halfOpen) max += 1; @@ -2501,8 +2504,8 @@ assert(arrayAttr2.size() <= shape.size()); for (auto [index, it] : llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) { - auto val1 = std::get<0>(it).template cast().getInt(); - auto val2 = std::get<1>(it).template cast().getInt(); + auto val1 = llvm::cast(std::get<0>(it)).getInt(); + auto val2 = llvm::cast(std::get<1>(it)).getInt(); int64_t max = std::get<2>(it); if (!halfOpen) max += 1; @@ -2643,7 +2646,7 @@ !destVector.hasOneUse()) return failure(); - auto denseDest = vectorDestCst.cast(); + auto denseDest = llvm::cast(vectorDestCst); TypedValue sourceValue = op.getSource(); Attribute sourceCst; @@ -2666,7 +2669,7 @@ // increasing linearized position indices. // Because the destination may have higher dimensionality then the slice, // we keep track of two overlapping sets of positions and offsets. - auto denseSlice = sourceCst.cast(); + auto denseSlice = llvm::cast(sourceCst); auto sliceValuesIt = denseSlice.value_begin(); auto newValues = llvm::to_vector(denseDest.getValues()); SmallVector currDestPosition(offsets.begin(), offsets.end()); @@ -2735,8 +2738,8 @@ if (operandsInfo.size() < 2) return parser.emitError(parser.getNameLoc(), "expected at least 2 operands"); - VectorType vLHS = tLHS.dyn_cast(); - VectorType vRHS = tRHS.dyn_cast(); + VectorType vLHS = llvm::dyn_cast(tLHS); + VectorType vRHS = llvm::dyn_cast(tRHS); if (!vLHS) return parser.emitError(parser.getNameLoc(), "expected vector type for operand #1"); @@ -2771,7 +2774,7 @@ LogicalResult OuterProductOp::verify() { Type tRHS = getOperandTypeRHS(); VectorType vLHS = getOperandVectorTypeLHS(), - vRHS = tRHS.dyn_cast(), + vRHS = llvm::dyn_cast(tRHS), vACC = getOperandVectorTypeACC(), vRES = getResultVectorType(); if (vLHS.getRank() != 1) @@ -2897,7 +2900,7 @@ shape.reserve(vectorType.getRank()); unsigned idx = 0; for (unsigned e = offsets.size(); idx < e; ++idx) - shape.push_back(sizes[idx].cast().getInt()); + shape.push_back(llvm::cast(sizes[idx]).getInt()); for (unsigned e = vectorType.getShape().size(); idx < e; ++idx) shape.push_back(vectorType.getShape()[idx]); @@ -2913,7 +2916,7 @@ auto sizesAttr = getVectorSubscriptAttr(builder, sizes); auto stridesAttr = getVectorSubscriptAttr(builder, strides); result.addTypes( - inferStridedSliceOpResultType(source.getType().cast(), + inferStridedSliceOpResultType(llvm::cast(source.getType()), offsetsAttr, sizesAttr, stridesAttr)); result.addAttribute(getOffsetsAttrStrName(), offsetsAttr); result.addAttribute(getSizesAttrStrName(), sizesAttr); @@ -2967,7 +2970,7 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) { // Helper to extract integer out of ArrayAttr. auto getElement = [](ArrayAttr array, int idx) { - return array[idx].cast().getInt(); + return llvm::cast(array[idx]).getInt(); }; ArrayAttr extractOffsets = op.getOffsets(); ArrayAttr extractStrides = op.getStrides(); @@ -3112,7 +3115,7 @@ if (!matchPattern(sourceVector, m_Constant(&vectorCst))) return failure(); - auto splat = vectorCst.dyn_cast(); + auto splat = llvm::dyn_cast(vectorCst); if (!splat) return failure(); @@ -3141,7 +3144,7 @@ return failure(); // The splat case is handled by `StridedSliceSplatConstantFolder`. - auto dense = vectorCst.dyn_cast(); + auto dense = llvm::dyn_cast(vectorCst); if (!dense || dense.isSplat()) return failure(); @@ -3149,7 +3152,7 @@ if (extractStridedSliceOp.hasNonUnitStrides()) return failure(); - auto sourceVecTy = sourceVector.getType().cast(); + auto sourceVecTy = llvm::cast(sourceVector.getType()); ArrayRef sourceShape = sourceVecTy.getShape(); SmallVector sourceStrides = computeStrides(sourceShape); @@ -3201,9 +3204,10 @@ auto broadcast = op.getVector().getDefiningOp(); if (!broadcast) return failure(); - auto srcVecType = broadcast.getSource().getType().dyn_cast(); + auto srcVecType = + llvm::dyn_cast(broadcast.getSource().getType()); unsigned srcRank = srcVecType ? srcVecType.getRank() : 0; - auto dstVecType = op.getType().cast(); + auto dstVecType = llvm::cast(op.getType()); unsigned dstRank = dstVecType.getRank(); unsigned rankDiff = dstRank - srcRank; // Check if the most inner dimensions of the source of the broadcast are the @@ -3269,7 +3273,7 @@ VectorType vectorType, Value source, ValueRange indices, AffineMapAttr permutationMapAttr, /*optional*/ ArrayAttr inBoundsAttr) { - Type elemType = source.getType().cast().getElementType(); + Type elemType = llvm::cast(source.getType()).getElementType(); Value padding = builder.create( result.location, elemType, builder.getZeroAttr(elemType)); build(builder, result, vectorType, source, indices, permutationMapAttr, @@ -3295,7 +3299,7 @@ ValueRange indices, Value padding, std::optional> inBounds) { AffineMap permutationMap = getTransferMinorIdentityMap( - source.getType().cast(), vectorType); + llvm::cast(source.getType()), vectorType); auto permutationMapAttr = AffineMapAttr::get(permutationMap); auto inBoundsAttr = (inBounds && !inBounds.value().empty()) ? builder.getBoolArrayAttr(inBounds.value()) @@ -3311,7 +3315,7 @@ VectorType vectorType, Value source, ValueRange indices, std::optional> inBounds) { - Type elemType = source.getType().cast().getElementType(); + Type elemType = llvm::cast(source.getType()).getElementType(); Value padding = builder.create( result.location, elemType, builder.getZeroAttr(elemType)); build(builder, result, vectorType, source, indices, padding, inBounds); @@ -3356,13 +3360,13 @@ "Use in_bounds instead."); } - if (!shapedType.isa()) + if (!llvm::isa(shapedType)) return op->emitOpError( "requires source to be a memref or ranked tensor type"); auto elementType = shapedType.getElementType(); DataLayout dataLayout = DataLayout::closest(op); - if (auto vectorElementType = elementType.dyn_cast()) { + if (auto vectorElementType = llvm::dyn_cast(elementType)) { // Memref or tensor has vector element type. unsigned sourceVecSize = dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) * @@ -3425,7 +3429,7 @@ << " vs inBounds of size: " << inBounds.size(); for (unsigned int i = 0; i < permutationMap.getNumResults(); ++i) if (permutationMap.getResult(i).isa() && - !inBounds.getValue()[i].cast().getValue()) + !llvm::cast(inBounds.getValue()[i]).getValue()) return op->emitOpError("requires broadcast dimensions to be in-bounds"); } @@ -3440,7 +3444,7 @@ bool elideInBounds = true; if (auto inBounds = op.in_bounds()) { for (auto attr : *inBounds) { - if (attr.template cast().getValue()) { + if (llvm::cast(attr).getValue()) { elideInBounds = false; break; } @@ -3496,10 +3500,10 @@ if (types.size() != 2) return parser.emitError(typesLoc, "requires two types"); auto indexType = builder.getIndexType(); - auto shapedType = types[0].dyn_cast(); - if (!shapedType || !shapedType.isa()) + auto shapedType = llvm::dyn_cast(types[0]); + if (!shapedType || !llvm::isa(shapedType)) return parser.emitError(typesLoc, "requires memref or ranked tensor type"); - VectorType vectorType = types[1].dyn_cast(); + VectorType vectorType = llvm::dyn_cast(types[1]); if (!vectorType) return parser.emitError(typesLoc, "requires vector type"); auto permMapAttrName = TransferReadOp::getPermutationMapAttrStrName(); @@ -3509,7 +3513,7 @@ permMap = getTransferMinorIdentityMap(shapedType, vectorType); result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap)); } else { - permMap = permMapAttr.cast().getValue(); + permMap = llvm::cast(permMapAttr).getValue(); } if (parser.resolveOperand(sourceInfo, shapedType, result.operands) || parser.resolveOperands(indexInfo, indexType, result.operands) || @@ -3517,7 +3521,7 @@ result.operands)) return failure(); if (hasMask.succeeded()) { - if (shapedType.getElementType().dyn_cast()) + if (llvm::dyn_cast(shapedType.getElementType())) return parser.emitError( maskInfo.location, "does not support masks with vector element type"); // Instead of adding the mask type as an op type, compute it based on the @@ -3554,7 +3558,8 @@ getInBounds() ? *getInBounds() : ArrayAttr()))) return failure(); - if (auto sourceVectorElementType = sourceElementType.dyn_cast()) { + if (auto sourceVectorElementType = + llvm::dyn_cast(sourceElementType)) { // Source has vector element type. // Check that 'sourceVectorElementType' and 'paddingType' types match. if (sourceVectorElementType != paddingType) @@ -3647,7 +3652,7 @@ /// %v0 /// ``` static Value foldRAW(TransferReadOp readOp) { - if (!readOp.getShapedType().isa()) + if (!llvm::isa(readOp.getShapedType())) return {}; auto defWrite = readOp.getSource().getDefiningOp(); while (defWrite) { @@ -3682,7 +3687,7 @@ void TransferReadOp::getEffects( SmallVectorImpl> &effects) { - if (getShapedType().isa()) + if (llvm::isa(getShapedType())) effects.emplace_back(MemoryEffects::Read::get(), getSource(), SideEffects::DefaultResource::get()); } @@ -3818,7 +3823,7 @@ LogicalResult matchAndRewrite(TransferReadOp readOp, PatternRewriter &rewriter) const override { if (readOp.hasOutOfBoundsDim() || - !readOp.getShapedType().isa()) + !llvm::isa(readOp.getShapedType())) return failure(); auto defWrite = readOp.getSource().getDefiningOp(); if (!defWrite) @@ -3889,7 +3894,7 @@ AffineMapAttr permutationMapAttr, /*optional*/ Value mask, /*optional*/ ArrayAttr inBoundsAttr) { - Type resultType = dest.getType().dyn_cast(); + Type resultType = llvm::dyn_cast(dest.getType()); build(builder, result, resultType, vector, dest, indices, permutationMapAttr, mask, inBoundsAttr); } @@ -3922,9 +3927,9 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result, Value vector, Value dest, ValueRange indices, std::optional> inBounds) { - auto vectorType = vector.getType().cast(); + auto vectorType = llvm::cast(vector.getType()); AffineMap permutationMap = getTransferMinorIdentityMap( - dest.getType().cast(), vectorType); + llvm::cast(dest.getType()), vectorType); build(builder, result, vector, dest, indices, permutationMap, inBounds); } @@ -3949,11 +3954,11 @@ if (types.size() != 2) return parser.emitError(typesLoc, "requires two types"); auto indexType = builder.getIndexType(); - VectorType vectorType = types[0].dyn_cast(); + VectorType vectorType = llvm::dyn_cast(types[0]); if (!vectorType) return parser.emitError(typesLoc, "requires vector type"); - ShapedType shapedType = types[1].dyn_cast(); - if (!shapedType || !shapedType.isa()) + ShapedType shapedType = llvm::dyn_cast(types[1]); + if (!shapedType || !llvm::isa(shapedType)) return parser.emitError(typesLoc, "requires memref or ranked tensor type"); auto permMapAttrName = TransferWriteOp::getPermutationMapAttrStrName(); auto permMapAttr = result.attributes.get(permMapAttrName); @@ -3962,14 +3967,14 @@ permMap = getTransferMinorIdentityMap(shapedType, vectorType); result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap)); } else { - permMap = permMapAttr.cast().getValue(); + permMap = llvm::cast(permMapAttr).getValue(); } if (parser.resolveOperand(vectorInfo, vectorType, result.operands) || parser.resolveOperand(sourceInfo, shapedType, result.operands) || parser.resolveOperands(indexInfo, indexType, result.operands)) return failure(); if (hasMask.succeeded()) { - if (shapedType.getElementType().dyn_cast()) + if (llvm::dyn_cast(shapedType.getElementType())) return parser.emitError( maskInfo.location, "does not support masks with vector element type"); auto maskType = inferTransferOpMaskType(vectorType, permMap); @@ -3980,7 +3985,7 @@ builder.getDenseI32ArrayAttr( {1, 1, static_cast(indexInfo.size()), static_cast(hasMask.succeeded())})); - return failure(shapedType.isa() && + return failure(llvm::isa(shapedType) && parser.addTypeToList(shapedType, result.types)); } @@ -4052,7 +4057,7 @@ if (write.getTransferRank() == 0) return failure(); auto rankedTensorType = - write.getSource().getType().dyn_cast(); + llvm::dyn_cast(write.getSource().getType()); // If not operating on tensors, bail. if (!rankedTensorType) return failure(); @@ -4119,7 +4124,7 @@ /// ``` static LogicalResult foldWAR(TransferWriteOp write, SmallVectorImpl &results) { - if (!write.getSource().getType().isa()) + if (!llvm::isa(write.getSource().getType())) return failure(); auto read = write.getVector().getDefiningOp(); if (!read) @@ -4149,7 +4154,7 @@ void TransferWriteOp::getEffects( SmallVectorImpl> &effects) { - if (getShapedType().isa()) + if (llvm::isa(getShapedType())) effects.emplace_back(MemoryEffects::Write::get(), getSource(), SideEffects::DefaultResource::get()); } @@ -4184,7 +4189,7 @@ using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TransferWriteOp writeOp, PatternRewriter &rewriter) const override { - if (!writeOp.getShapedType().isa()) + if (!llvm::isa(writeOp.getShapedType())) return failure(); vector::TransferWriteOp writeToModify = writeOp; @@ -4439,7 +4444,7 @@ // Checks for vector memrefs. Type memElemTy = memRefTy.getElementType(); - if (auto memVecTy = memElemTy.dyn_cast()) { + if (auto memVecTy = llvm::dyn_cast(memElemTy)) { if (memVecTy != resVecTy) return emitOpError("base memref and result vector types should match"); memElemTy = memVecTy.getElementType(); @@ -4471,7 +4476,7 @@ // Checks for vector memrefs. Type memElemTy = memRefTy.getElementType(); - if (auto memVecTy = memElemTy.dyn_cast()) { + if (auto memVecTy = llvm::dyn_cast(memElemTy)) { if (memVecTy != valueVecTy) return emitOpError( "base memref and valueToStore vector types should match"); @@ -4604,7 +4609,7 @@ VectorType resVType = getVectorType(); ShapedType baseType = getBaseType(); - if (!baseType.isa()) + if (!llvm::isa(baseType)) return emitOpError("requires base to be a memref or ranked tensor type"); if (resVType.getElementType() != baseType.getElementType()) @@ -4864,8 +4869,10 @@ } LogicalResult ShapeCastOp::verify() { - auto sourceVectorType = getSource().getType().dyn_cast_or_null(); - auto resultVectorType = getResult().getType().dyn_cast_or_null(); + auto sourceVectorType = + llvm::dyn_cast_or_null(getSource().getType()); + auto resultVectorType = + llvm::dyn_cast_or_null(getResult().getType()); // Check if source/result are of vector type. if (sourceVectorType && resultVectorType) @@ -4885,8 +4892,8 @@ return otherOp.getSource(); // Only allows valid transitive folding. - VectorType srcType = otherOp.getSource().getType().cast(); - VectorType resultType = getResult().getType().cast(); + VectorType srcType = llvm::cast(otherOp.getSource().getType()); + VectorType resultType = llvm::cast(getResult().getType()); if (srcType.getRank() < resultType.getRank()) { if (!isValidShapeCast(srcType.getShape(), resultType.getShape())) return {}; @@ -4923,11 +4930,11 @@ if (!constantOp) return failure(); // Only handle splat for now. - auto dense = constantOp.getValue().dyn_cast(); + auto dense = llvm::dyn_cast(constantOp.getValue()); if (!dense) return failure(); auto newAttr = - DenseElementsAttr::get(shapeCastOp.getType().cast(), + DenseElementsAttr::get(llvm::cast(shapeCastOp.getType()), dense.getSplatValue()); rewriter.replaceOpWithNewOp(shapeCastOp, newAttr); return success(); @@ -4950,7 +4957,7 @@ return failure(); auto broadcastSourceVectorType = - broadcastOp.getSourceType().dyn_cast(); + llvm::dyn_cast(broadcastOp.getSourceType()); auto broadcastSourceShape = broadcastSourceVectorType ? broadcastSourceVectorType.getShape() : ArrayRef{}; @@ -5029,7 +5036,7 @@ Type srcElemType = getSourceVectorType().getElementType(); Type dstElemType = getResultVectorType().getElementType(); - if (auto floatPack = sourceConstant.dyn_cast()) { + if (auto floatPack = llvm::dyn_cast(sourceConstant)) { if (floatPack.isSplat()) { auto splat = floatPack.getSplatValue(); @@ -5046,11 +5053,11 @@ } } - if (auto intPack = sourceConstant.dyn_cast()) { + if (auto intPack = llvm::dyn_cast(sourceConstant)) { if (intPack.isSplat()) { auto splat = intPack.getSplatValue(); - if (dstElemType.isa()) { + if (llvm::isa(dstElemType)) { uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth(); uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth(); @@ -5075,7 +5082,7 @@ //===----------------------------------------------------------------------===// static SmallVector extractShape(MemRefType memRefType) { - auto vectorType = memRefType.getElementType().dyn_cast(); + auto vectorType = llvm::dyn_cast(memRefType.getElementType()); SmallVector res(memRefType.getShape().begin(), memRefType.getShape().end()); if (vectorType) @@ -5088,7 +5095,7 @@ void TypeCastOp::build(OpBuilder &builder, OperationState &result, Value source) { result.addOperands(source); - MemRefType memRefType = source.getType().cast(); + MemRefType memRefType = llvm::cast(source.getType()); VectorType vectorType = VectorType::get(extractShape(memRefType), getElementTypeOrSelf(getElementTypeOrSelf(memRefType))); @@ -5126,7 +5133,7 @@ void vector::TransposeOp::build(OpBuilder &builder, OperationState &result, Value vector, ArrayRef transp) { - VectorType vt = vector.getType().cast(); + VectorType vt = llvm::cast(vector.getType()); SmallVector transposedShape(vt.getRank()); for (unsigned i = 0; i < transp.size(); ++i) transposedShape[i] = vt.getShape()[transp[i]]; @@ -5170,7 +5177,7 @@ return emitOpError("transposition length mismatch: ") << size; SmallVector seen(rank, false); for (const auto &ta : llvm::enumerate(transpAttr)) { - int64_t i = ta.value().cast().getInt(); + int64_t i = llvm::cast(ta.value()).getInt(); if (i < 0 || i >= rank) return emitOpError("transposition index out of range: ") << i; if (seen[i]) @@ -5239,7 +5246,7 @@ if (!bcastOp) return failure(); - auto srcVectorType = bcastOp.getSourceType().dyn_cast(); + auto srcVectorType = llvm::dyn_cast(bcastOp.getSourceType()); if (!srcVectorType || srcVectorType.getNumElements() == 1) { rewriter.replaceOpWithNewOp( transposeOp, transposeOp.getResultVectorType(), bcastOp.getSource()); @@ -5324,12 +5331,12 @@ //===----------------------------------------------------------------------===// LogicalResult ConstantMaskOp::verify() { - auto resultType = getResult().getType().cast(); + auto resultType = llvm::cast(getResult().getType()); // Check the corner case of 0-D vectors first. if (resultType.getRank() == 0) { if (getMaskDimSizes().size() != 1) return emitError("array attr must have length 1 for 0-D vectors"); - auto dim = getMaskDimSizes()[0].cast().getInt(); + auto dim = llvm::cast(getMaskDimSizes()[0]).getInt(); if (dim != 0 && dim != 1) return emitError("mask dim size must be either 0 or 1 for 0-D vectors"); return success(); @@ -5344,7 +5351,7 @@ auto resultShape = resultType.getShape(); SmallVector maskDimSizes; for (const auto &it : llvm::enumerate(getMaskDimSizes())) { - int64_t attrValue = it.value().cast().getInt(); + int64_t attrValue = llvm::cast(it.value()).getInt(); if (attrValue < 0 || attrValue > resultShape[it.index()]) return emitOpError( "array attr of size out of bounds of vector result dimension size"); @@ -5363,7 +5370,7 @@ // `vector.constant_mask`. In the future, a convention could be established // to decide if a specific dimension value could be considered as "all set". if (resultType.isScalable() && - getMaskDimSizes()[0].cast().getInt() != 0) + llvm::cast(getMaskDimSizes()[0]).getInt() != 0) return emitOpError("expected mask dim sizes for scalable masks to be 0"); return success(); } @@ -5381,14 +5388,14 @@ } LogicalResult CreateMaskOp::verify() { - auto vectorType = getResult().getType().cast(); + auto vectorType = llvm::cast(getResult().getType()); // Verify that an operand was specified for each result vector each dimension. if (vectorType.getRank() == 0) { if (getNumOperands() != 1) return emitOpError( "must specify exactly one operand for 0-D create_mask"); } else if (getNumOperands() != - getResult().getType().cast().getRank()) { + llvm::cast(getResult().getType()).getRank()) { return emitOpError( "must specify an operand for each result vector dimension"); } @@ -5413,7 +5420,7 @@ // CreateMaskOp for scalable vectors can be folded only if all dimensions // are negative or zero. - if (auto vType = createMaskOp.getType().dyn_cast()) { + if (auto vType = llvm::dyn_cast(createMaskOp.getType())) { if (vType.isScalable()) for (auto opDim : createMaskOp.getOperands()) { APInt intVal; @@ -5615,7 +5622,7 @@ "expects result type to match maskable operation result type"); if (llvm::count_if(maskableOp->getResultTypes(), - [](Type t) { return t.isa(); }) > 1) + [](Type t) { return llvm::isa(t); }) > 1) return emitOpError("multiple vector results not supported"); // Mask checks. @@ -5759,7 +5766,7 @@ SmallVector coreAttr = {getWarpSizeAttrName()}; auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName()); - p << "[" << warpSizeAttr.cast().getInt() << "]"; + p << "[" << llvm::cast(warpSizeAttr).getInt() << "]"; if (!getArgs().empty()) p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")"; @@ -5872,8 +5879,8 @@ // If the types matches there is no distribution. if (expanded == distributed) return success(); - auto expandedVecType = expanded.dyn_cast(); - auto distributedVecType = distributed.dyn_cast(); + auto expandedVecType = llvm::dyn_cast(expanded); + auto distributedVecType = llvm::dyn_cast(distributed); if (!expandedVecType || !distributedVecType) return op->emitOpError("expected vector type for distributed operands."); if (expandedVecType.getRank() != distributedVecType.getRank() || @@ -5940,7 +5947,7 @@ case CombiningKind::ADD: if (t1.isIntOrIndex() && tAcc.isIntOrIndex()) result = b.createOrFold(loc, v1, acc); - else if (t1.isa() && tAcc.isa()) + else if (llvm::isa(t1) && llvm::isa(tAcc)) result = b.createOrFold(loc, v1, acc); else llvm_unreachable("invalid value types for ADD reduction"); @@ -5950,12 +5957,12 @@ result = b.createOrFold(loc, v1, acc); break; case CombiningKind::MAXF: - assert(t1.isa() && tAcc.isa() && + assert(llvm::isa(t1) && llvm::isa(tAcc) && "expected float values"); result = b.createOrFold(loc, v1, acc); break; case CombiningKind::MINF: - assert(t1.isa() && tAcc.isa() && + assert(llvm::isa(t1) && llvm::isa(tAcc) && "expected float values"); result = b.createOrFold(loc, v1, acc); break; @@ -5978,7 +5985,7 @@ case CombiningKind::MUL: if (t1.isIntOrIndex() && tAcc.isIntOrIndex()) result = b.createOrFold(loc, v1, acc); - else if (t1.isa() && tAcc.isa()) + else if (llvm::isa(t1) && llvm::isa(tAcc)) result = b.createOrFold(loc, v1, acc); else llvm_unreachable("invalid value types for MUL reduction"); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp @@ -231,7 +231,7 @@ Value vectorMask = maskableOp.getMaskingOp().getMask(); auto maskCastedType = VectorType::get( vectorShape, - vectorMask.getType().cast().getElementType()); + llvm::cast(vectorMask.getType()).getElementType()); newVectorMask = rewriter.create(loc, maskCastedType, vectorMask); } @@ -413,7 +413,7 @@ srcVectorType.getElementType()); auto accType = VectorType::get(ArrayRef{1}, srcVectorType.getElementType()); - assert(!multiReductionOp.getDestType().isa() && + assert(!llvm::isa(multiReductionOp.getDestType()) && "multi_reduction with a single dimension expects a scalar result"); // If the unique dim is reduced and we insert a parallel in front, we need a @@ -427,7 +427,7 @@ loc, accType, multiReductionOp.getAcc()); Value castMask; if (maskableOp.isMasked()) { - auto maskType = mask.getType().cast(); + auto maskType = llvm::cast(mask.getType()); auto castMaskType = VectorType::get(ArrayRef{1, maskType.getShape().back()}, maskType.getElementType()); diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -66,14 +66,14 @@ case AffineExprKind::Constant: return expr.cast().getValue(); case AffineExprKind::DimId: - if (auto attr = operandConsts[expr.cast().getPosition()] - .dyn_cast_or_null()) + if (auto attr = llvm::dyn_cast_or_null( + operandConsts[expr.cast().getPosition()])) return attr.getInt(); return std::nullopt; case AffineExprKind::SymbolId: - if (auto attr = operandConsts[numDims + - expr.cast().getPosition()] - .dyn_cast_or_null()) + if (auto attr = llvm::dyn_cast_or_null( + operandConsts[numDims + + expr.cast().getPosition()])) return attr.getInt(); return std::nullopt; } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -91,7 +91,7 @@ // it is a function (avoiding a grammar ambiguity). bool wrapped = op->getNumResults() != 1; if (!wrapped && op->getResult(0).getType() && - op->getResult(0).getType().isa()) + llvm::isa(op->getResult(0).getType())) wrapped = true; if (wrapped) @@ -254,7 +254,7 @@ bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const { return elementsAttrElementLimit && *elementsAttrElementLimit < int64_t(attr.getNumElements()) && - !attr.isa(); + !llvm::isa(attr); } /// Return the size limit for printing large ElementsAttr. @@ -803,8 +803,8 @@ attr.getDialect().printAttribute(attr, *this); // Process the builtin attributes. - } else if (attr.isa()) { + } else if (llvm::isa(attr)) { return; } else if (auto dictAttr = dyn_cast(attr)) { for (const NamedAttribute &nestedAttr : dictAttr.getValue()) { @@ -833,9 +833,9 @@ // Don't print the type if we must elide it, or if it is a None type. if (!elideType) { - if (auto typedAttr = attr.dyn_cast()) { + if (auto typedAttr = llvm::dyn_cast(attr)) { Type attrType = typedAttr.getType(); - if (!attrType.isa()) + if (!llvm::isa(attrType)) printType(attrType); } } @@ -845,10 +845,10 @@ return type.getDialect().printType(type, *this); // Only visit the layout of memref if it isn't the identity. - if (auto memrefTy = type.dyn_cast()) { + if (auto memrefTy = llvm::dyn_cast(type)) { printType(memrefTy.getElementType()); MemRefLayoutAttrInterface layout = memrefTy.getLayout(); - if (!layout.isa() || !layout.isIdentity()) + if (!llvm::isa(layout) || !layout.isIdentity()) printAttribute(memrefTy.getLayout()); if (memrefTy.getMemorySpace()) printAttribute(memrefTy.getMemorySpace()); @@ -1418,7 +1418,7 @@ void SSANameState::numberValuesInRegion(Region ®ion) { auto setBlockArgNameFn = [&](Value arg, StringRef name) { assert(!valueIDs.count(arg) && "arg numbered multiple times"); - assert(arg.cast().getOwner()->getParent() == ®ion && + assert(llvm::cast(arg).getOwner()->getParent() == ®ion && "arg not defined in current region"); setValueName(arg, name); }; @@ -1479,7 +1479,7 @@ setValueName(result, name); // Record the result number for groups not anchored at 0. - if (int resultNo = result.cast().getResultNumber()) + if (int resultNo = llvm::cast(result).getResultNumber()) resultGroups.push_back(resultNo); }; // Operations can customize the printing of block names in OpAsmOpInterface. @@ -1878,7 +1878,7 @@ // Print the child if it isn't unknown. auto childLoc = loc.getChildLoc(); - if (!childLoc.isa()) { + if (!llvm::isa(childLoc)) { os << '('; printLocationInternal(childLoc, pretty); os << ')'; @@ -1891,8 +1891,8 @@ os << "callsite("; printLocationInternal(callee, pretty); if (pretty) { - if (callee.isa()) { - if (caller.isa()) { + if (llvm::isa(callee)) { + if (llvm::isa(caller)) { os << " at "; } else { os << newLine << " at "; @@ -2100,19 +2100,19 @@ AttrTypeElision typeElision) { if (!isa(attr.getDialect())) { printDialectAttribute(attr); - } else if (auto opaqueAttr = attr.dyn_cast()) { + } else if (auto opaqueAttr = llvm::dyn_cast(attr)) { printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(), opaqueAttr.getAttrData()); - } else if (attr.isa()) { + } else if (llvm::isa(attr)) { os << "unit"; return; - } else if (auto dictAttr = attr.dyn_cast()) { + } else if (auto dictAttr = llvm::dyn_cast(attr)) { os << '{'; interleaveComma(dictAttr.getValue(), [&](NamedAttribute attr) { printNamedAttribute(attr); }); os << '}'; - } else if (auto intAttr = attr.dyn_cast()) { + } else if (auto intAttr = llvm::dyn_cast(attr)) { Type intType = intAttr.getType(); if (intType.isSignlessInteger(1)) { os << (intAttr.getValue().getBoolValue() ? "true" : "false"); @@ -2132,24 +2132,24 @@ if (typeElision == AttrTypeElision::May && intType.isSignlessInteger(64)) return; - } else if (auto floatAttr = attr.dyn_cast()) { + } else if (auto floatAttr = llvm::dyn_cast(attr)) { printFloatValue(floatAttr.getValue(), os); // FloatAttr elides the type if F64. if (typeElision == AttrTypeElision::May && floatAttr.getType().isF64()) return; - } else if (auto strAttr = attr.dyn_cast()) { + } else if (auto strAttr = llvm::dyn_cast(attr)) { printEscapedString(strAttr.getValue()); - } else if (auto arrayAttr = attr.dyn_cast()) { + } else if (auto arrayAttr = llvm::dyn_cast(attr)) { os << '['; interleaveComma(arrayAttr.getValue(), [&](Attribute attr) { printAttribute(attr, AttrTypeElision::May); }); os << ']'; - } else if (auto affineMapAttr = attr.dyn_cast()) { + } else if (auto affineMapAttr = llvm::dyn_cast(attr)) { os << "affine_map<"; affineMapAttr.getValue().print(os); os << '>'; @@ -2157,7 +2157,7 @@ // AffineMap always elides the type. return; - } else if (auto integerSetAttr = attr.dyn_cast()) { + } else if (auto integerSetAttr = llvm::dyn_cast(attr)) { os << "affine_set<"; integerSetAttr.getValue().print(os); os << '>'; @@ -2165,17 +2165,18 @@ // IntegerSet always elides the type. return; - } else if (auto typeAttr = attr.dyn_cast()) { + } else if (auto typeAttr = llvm::dyn_cast(attr)) { printType(typeAttr.getValue()); - } else if (auto refAttr = attr.dyn_cast()) { + } else if (auto refAttr = llvm::dyn_cast(attr)) { printSymbolReference(refAttr.getRootReference().getValue(), os); for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) { os << "::"; printSymbolReference(nestedRef.getValue(), os); } - } else if (auto intOrFpEltAttr = attr.dyn_cast()) { + } else if (auto intOrFpEltAttr = + llvm::dyn_cast(attr)) { if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) { printElidedElementsAttr(os); } else { @@ -2184,7 +2185,7 @@ os << '>'; } - } else if (auto strEltAttr = attr.dyn_cast()) { + } else if (auto strEltAttr = llvm::dyn_cast(attr)) { if (printerFlags.shouldElideElementsAttr(strEltAttr)) { printElidedElementsAttr(os); } else { @@ -2193,7 +2194,7 @@ os << '>'; } - } else if (auto sparseEltAttr = attr.dyn_cast()) { + } else if (auto sparseEltAttr = llvm::dyn_cast(attr)) { if (printerFlags.shouldElideElementsAttr(sparseEltAttr.getIndices()) || printerFlags.shouldElideElementsAttr(sparseEltAttr.getValues())) { printElidedElementsAttr(os); @@ -2207,9 +2208,9 @@ } os << '>'; } - } else if (auto stridedLayoutAttr = attr.dyn_cast()) { + } else if (auto stridedLayoutAttr = llvm::dyn_cast(attr)) { stridedLayoutAttr.print(os); - } else if (auto denseArrayAttr = attr.dyn_cast()) { + } else if (auto denseArrayAttr = llvm::dyn_cast(attr)) { os << "array<"; printType(denseArrayAttr.getElementType()); if (!denseArrayAttr.empty()) { @@ -2218,20 +2219,21 @@ } os << ">"; return; - } else if (auto resourceAttr = attr.dyn_cast()) { + } else if (auto resourceAttr = + llvm::dyn_cast(attr)) { os << "dense_resource<"; printResourceHandle(resourceAttr.getRawHandle()); os << ">"; - } else if (auto locAttr = attr.dyn_cast()) { + } else if (auto locAttr = llvm::dyn_cast(attr)) { printLocation(locAttr); } else { llvm::report_fatal_error("Unknown builtin attribute"); } // Don't print the type if we must elide it, or if it is a None type. if (typeElision != AttrTypeElision::Must) { - if (auto typedAttr = attr.dyn_cast()) { + if (auto typedAttr = llvm::dyn_cast(attr)) { Type attrType = typedAttr.getType(); - if (!attrType.isa()) { + if (!llvm::isa(attrType)) { os << " : "; printType(attrType); } @@ -2300,10 +2302,10 @@ void AsmPrinter::Impl::printDenseElementsAttr(DenseElementsAttr attr, bool allowHex) { - if (auto stringAttr = attr.dyn_cast()) + if (auto stringAttr = llvm::dyn_cast(attr)) return printDenseStringElementsAttr(stringAttr); - printDenseIntOrFPElementsAttr(attr.cast(), + printDenseIntOrFPElementsAttr(llvm::cast(attr), allowHex); } @@ -2333,12 +2335,12 @@ return; } - if (ComplexType complexTy = elementType.dyn_cast()) { + if (ComplexType complexTy = llvm::dyn_cast(elementType)) { Type complexElementType = complexTy.getElementType(); // Note: The if and else below had a common lambda function which invoked // printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2 // and hence was replaced. - if (complexElementType.isa()) { + if (llvm::isa(complexElementType)) { auto valueIt = attr.value_begin>(); printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { auto complexValue = *(valueIt + index); @@ -2365,7 +2367,7 @@ printDenseIntElement(*(valueIt + index), os, elementType); }); } else { - assert(elementType.isa() && "unexpected element type"); + assert(llvm::isa(elementType) && "unexpected element type"); auto valueIt = attr.value_begin(); printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { printFloatValue(*(valueIt + index), os); @@ -2397,7 +2399,7 @@ if (type.isIntOrIndex()) { printDenseIntElement(value, getStream(), type); } else { - APFloat fltVal(type.cast().getFloatSemantics(), value); + APFloat fltVal(llvm::cast(type).getFloatSemantics(), value); printFloatValue(fltVal, getStream()); } }; @@ -2447,7 +2449,7 @@ interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); }); os << ") -> "; ArrayRef results = funcTy.getResults(); - if (results.size() == 1 && !results[0].isa()) { + if (results.size() == 1 && !llvm::isa(results[0])) { printType(results[0]); } else { os << '('; @@ -2506,7 +2508,7 @@ } printType(memrefTy.getElementType()); MemRefLayoutAttrInterface layout = memrefTy.getLayout(); - if (!layout.isa() || !layout.isIdentity()) { + if (!llvm::isa(layout) || !layout.isIdentity()) { os << ", "; printAttribute(memrefTy.getLayout(), AttrTypeElision::May); } @@ -2580,7 +2582,7 @@ ::printKeywordOrString(attr.getName().strref(), os); // Pretty printing elides the attribute value for unit attributes. - if (attr.getValue().isa()) + if (llvm::isa(attr.getValue())) return; os << " = "; diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -33,7 +33,7 @@ /// Return the bit width which DenseElementsAttr should use for this type. inline size_t getDenseElementBitWidth(Type eltType) { // Align the width for complex to 8 to make storage and interpretation easier. - if (ComplexType comp = eltType.dyn_cast()) + if (ComplexType comp = llvm::dyn_cast(eltType)) return llvm::alignTo<8>(getDenseElementBitWidth(comp.getElementType())) * 2; if (eltType.isIndex()) return IndexType::kInternalStorageBitWidth; diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -46,7 +46,9 @@ assert(name.size() != 0 && "expected valid attribute name"); } -StringAttr NamedAttribute::getName() const { return name.cast(); } +StringAttr NamedAttribute::getName() const { + return llvm::cast(name); +} Dialect *NamedAttribute::getNameDialect() const { return getName().getReferencedDialect(); diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -316,14 +316,15 @@ } TypedAttr Builder::getZeroAttr(Type type) { - if (type.isa()) + if (llvm::isa(type)) return getFloatAttr(type, 0.0); - if (type.isa()) + if (llvm::isa(type)) return getIndexAttr(0); - if (auto integerType = type.dyn_cast()) - return getIntegerAttr(type, APInt(type.cast().getWidth(), 0)); - if (type.isa()) { - auto vtType = type.cast(); + if (auto integerType = llvm::dyn_cast(type)) + return getIntegerAttr(type, + APInt(llvm::cast(type).getWidth(), 0)); + if (llvm::isa(type)) { + auto vtType = llvm::cast(type); auto element = getZeroAttr(vtType.getElementType()); if (!element) return {}; diff --git a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp --- a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp +++ b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp @@ -53,7 +53,7 @@ } uint64_t ElementsAttr::getFlattenedIndex(Type type, ArrayRef index) { - ShapedType shapeType = type.cast(); + ShapedType shapeType = llvm::cast(type); assert(isValidIndex(shapeType, index) && "expected valid multi-dimensional index"); diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -300,11 +300,12 @@ LogicalResult FloatAttr::verify(function_ref emitError, Type type, APFloat value) { // Verify that the type is correct. - if (!type.isa()) + if (!llvm::isa(type)) return emitError() << "expected floating point type"; // Verify that the type semantics match that of the value. - if (&type.cast().getFloatSemantics() != &value.getSemantics()) { + if (&llvm::cast(type).getFloatSemantics() != + &value.getSemantics()) { return emitError() << "FloatAttr type doesn't match the type implied by its value"; } @@ -321,11 +322,11 @@ } FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) { - return get(ctx, value, {}).cast(); + return llvm::cast(get(ctx, value, {})); } FlatSymbolRefAttr SymbolRefAttr::get(StringAttr value) { - return get(value, {}).cast(); + return llvm::cast(get(value, {})); } FlatSymbolRefAttr SymbolRefAttr::get(Operation *symbol) { @@ -370,14 +371,14 @@ LogicalResult IntegerAttr::verify(function_ref emitError, Type type, APInt value) { - if (IntegerType integerType = type.dyn_cast()) { + if (IntegerType integerType = llvm::dyn_cast(type)) { if (integerType.getWidth() != value.getBitWidth()) return emitError() << "integer type bit width (" << integerType.getWidth() << ") doesn't match value bit width (" << value.getBitWidth() << ")"; return success(); } - if (type.isa()) { + if (llvm::isa(type)) { if (value.getBitWidth() != IndexType::kInternalStorageBitWidth) return emitError() << "value bit width (" << value.getBitWidth() @@ -390,7 +391,7 @@ BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) { auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value)); - return attr.cast(); + return llvm::cast(attr); } //===----------------------------------------------------------------------===// @@ -403,7 +404,7 @@ } bool BoolAttr::classof(Attribute attr) { - IntegerAttr intAttr = attr.dyn_cast(); + IntegerAttr intAttr = llvm::dyn_cast(attr); return intAttr && intAttr.getType().isSignlessInteger(1); } @@ -600,21 +601,21 @@ attr.getAsOpaquePointer(), index) {} Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { - auto owner = getFromOpaquePointer(base).cast(); + auto owner = llvm::cast(getFromOpaquePointer(base)); Type eltTy = owner.getElementType(); - if (auto intEltTy = eltTy.dyn_cast()) + if (auto intEltTy = llvm::dyn_cast(eltTy)) return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); - if (eltTy.isa()) + if (llvm::isa(eltTy)) return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); - if (auto floatEltTy = eltTy.dyn_cast()) { + if (auto floatEltTy = llvm::dyn_cast(eltTy)) { IntElementIterator intIt(owner, index); FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); return FloatAttr::get(eltTy, *floatIt); } - if (auto complexTy = eltTy.dyn_cast()) { + if (auto complexTy = llvm::dyn_cast(eltTy)) { auto complexEltTy = complexTy.getElementType(); ComplexIntElementIterator complexIntIt(owner, index); - if (complexEltTy.isa()) { + if (llvm::isa(complexEltTy)) { auto value = *complexIntIt; auto real = IntegerAttr::get(complexEltTy, value.real()); auto imag = IntegerAttr::get(complexEltTy, value.imag()); @@ -623,14 +624,14 @@ } ComplexFloatElementIterator complexFloatIt( - complexEltTy.cast().getFloatSemantics(), complexIntIt); + llvm::cast(complexEltTy).getFloatSemantics(), complexIntIt); auto value = *complexFloatIt; auto real = FloatAttr::get(complexEltTy, value.real()); auto imag = FloatAttr::get(complexEltTy, value.imag()); return ArrayAttr::get(complexTy.getContext(), ArrayRef{real, imag}); } - if (owner.isa()) { + if (llvm::isa(owner)) { ArrayRef vals = owner.getRawStringData(); return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy); } @@ -673,7 +674,7 @@ std::complex, std::complex, std::complex>( attr.getRawData().data(), attr.isSplat(), dataIndex) { - auto complexType = attr.getElementType().cast(); + auto complexType = llvm::cast(attr.getElementType()); bitWidth = getDenseElementBitWidth(complexType.getElementType()); } @@ -713,7 +714,7 @@ IntegerType::SignednessSemantics signedness = IntegerType::Signless> struct DenseArrayAttrIntUtil { static bool checkElementType(Type eltType) { - auto type = eltType.dyn_cast(); + auto type = llvm::dyn_cast(eltType); if (!type || type.getWidth() != width) return false; return type.getSignedness() == signedness; @@ -860,7 +861,7 @@ template bool DenseArrayAttrImpl::classof(Attribute attr) { - if (auto denseArray = attr.dyn_cast()) + if (auto denseArray = llvm::dyn_cast(attr)) return DenseArrayAttrUtil::checkElementType(denseArray.getElementType()); return false; } @@ -884,7 +885,7 @@ /// Method for support type inquiry through isa, cast and dyn_cast. bool DenseElementsAttr::classof(Attribute attr) { - return attr.isa(); + return llvm::isa(attr); } DenseElementsAttr DenseElementsAttr::get(ShapedType type, @@ -894,20 +895,19 @@ Type eltType = type.getElementType(); // Take care complex type case first. - if (auto complexType = eltType.dyn_cast()) { + if (auto complexType = llvm::dyn_cast(eltType)) { if (complexType.getElementType().isIntOrIndex()) { SmallVector> complexValues; complexValues.reserve(values.size()); for (Attribute attr : values) { - assert(attr.isa() && - "expected ArrayAttr for complex"); - auto arrayAttr = attr.cast(); + assert(llvm::isa(attr) && "expected ArrayAttr for complex"); + auto arrayAttr = llvm::cast(attr); assert(arrayAttr.size() == 2 && "expected 2 element for complex"); auto attr0 = arrayAttr[0]; auto attr1 = arrayAttr[1]; complexValues.push_back( - std::complex(attr0.cast().getValue(), - attr1.cast().getValue())); + std::complex(llvm::cast(attr0).getValue(), + llvm::cast(attr1).getValue())); } return DenseElementsAttr::get(type, complexValues); } @@ -915,14 +915,14 @@ SmallVector> complexValues; complexValues.reserve(values.size()); for (Attribute attr : values) { - assert(attr.isa() && "expected ArrayAttr for complex"); - auto arrayAttr = attr.cast(); + assert(llvm::isa(attr) && "expected ArrayAttr for complex"); + auto arrayAttr = llvm::cast(attr); assert(arrayAttr.size() == 2 && "expected 2 element for complex"); auto attr0 = arrayAttr[0]; auto attr1 = arrayAttr[1]; complexValues.push_back( - std::complex(attr0.cast().getValue(), - attr1.cast().getValue())); + std::complex(llvm::cast(attr0).getValue(), + llvm::cast(attr1).getValue())); } return DenseElementsAttr::get(type, complexValues); } @@ -933,9 +933,9 @@ SmallVector stringValues; stringValues.reserve(values.size()); for (Attribute attr : values) { - assert(attr.isa() && + assert(llvm::isa(attr) && "expected string value for non integer/index/float element"); - stringValues.push_back(attr.cast().getValue()); + stringValues.push_back(llvm::cast(attr).getValue()); } return get(type, stringValues); } @@ -949,12 +949,12 @@ llvm::divideCeil(storageBitWidth * values.size(), CHAR_BIT)); APInt intVal; for (unsigned i = 0, e = values.size(); i < e; ++i) { - if (auto floatAttr = values[i].dyn_cast()) { + if (auto floatAttr = llvm::dyn_cast(values[i])) { assert(floatAttr.getType() == eltType && "expected float attribute type to equal element type"); intVal = floatAttr.getValue().bitcastToAPInt(); } else { - auto intAttr = values[i].cast(); + auto intAttr = llvm::cast(values[i]); assert(intAttr.getType() == eltType && "expected integer attribute type to equal element type"); intVal = intAttr.getValue(); @@ -1015,8 +1015,8 @@ } DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef> values) { - ComplexType complex = type.getElementType().cast(); - assert(complex.getElementType().isa()); + ComplexType complex = llvm::cast(type.getElementType()); + assert(llvm::isa(complex.getElementType())); assert(hasSameElementsOrSplat(type, values)); size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; ArrayRef intVals(reinterpret_cast(values.data()), @@ -1029,7 +1029,7 @@ // element type of 'type'. DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef values) { - assert(type.getElementType().isa()); + assert(llvm::isa(type.getElementType())); assert(hasSameElementsOrSplat(type, values)); size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values); @@ -1037,8 +1037,8 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef> values) { - ComplexType complex = type.getElementType().cast(); - assert(complex.getElementType().isa()); + ComplexType complex = llvm::cast(type.getElementType()); + assert(llvm::isa(complex.getElementType())); assert(hasSameElementsOrSplat(type, values)); ArrayRef apVals(reinterpret_cast(values.data()), values.size() * 2); @@ -1104,11 +1104,11 @@ // Check that the element type is either float or integer or index. if (!isInt) - return type.isa(); + return llvm::isa(type); if (type.isIndex()) return true; - auto intType = type.dyn_cast(); + auto intType = llvm::dyn_cast(type); if (!intType) return false; @@ -1142,8 +1142,8 @@ bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt, bool isSigned) const { return ::isValidIntOrFloat( - getElementType().cast().getElementType(), dataEltSize / 2, - isInt, isSigned); + llvm::cast(getElementType()).getElementType(), + dataEltSize / 2, isInt, isSigned); } /// Returns true if this attribute corresponds to a splat, i.e. if all element @@ -1154,7 +1154,7 @@ /// Return if the given complex type has an integer element type. static bool isComplexOfIntType(Type type) { - return type.cast().getElementType().isa(); + return llvm::isa(llvm::cast(type).getElementType()); } auto DenseElementsAttr::tryGetComplexIntValues() const @@ -1168,7 +1168,7 @@ auto DenseElementsAttr::tryGetFloatValues() const -> FailureOr> { - auto eltTy = getElementType().dyn_cast(); + auto eltTy = llvm::dyn_cast(getElementType()); if (!eltTy) return failure(); const auto &elementSemantics = eltTy.getFloatSemantics(); @@ -1179,10 +1179,10 @@ auto DenseElementsAttr::tryGetComplexFloatValues() const -> FailureOr> { - auto complexTy = getElementType().dyn_cast(); + auto complexTy = llvm::dyn_cast(getElementType()); if (!complexTy) return failure(); - auto eltTy = complexTy.getElementType().dyn_cast(); + auto eltTy = llvm::dyn_cast(complexTy.getElementType()); if (!eltTy) return failure(); const auto &semantics = eltTy.getFloatSemantics(); @@ -1331,7 +1331,7 @@ bool isInt, bool isSigned) { assert(::isValidIntOrFloat( - type.getElementType().cast().getElementType(), + llvm::cast(type.getElementType()).getElementType(), dataEltSize / 2, isInt, isSigned)); int64_t numElements = data.size() / dataEltSize; @@ -1404,7 +1404,7 @@ ShapedType type) { size_t numElements = type.getNumElements(); Type elementType = type.getElementType(); - if (ComplexType complexTy = elementType.dyn_cast()) { + if (ComplexType complexTy = llvm::dyn_cast(elementType)) { elementType = complexTy.getElementType(); numElements = numElements * 2; } @@ -1470,8 +1470,8 @@ /// Method for supporting type inquiry through isa, cast and dyn_cast. bool DenseFPElementsAttr::classof(Attribute attr) { - if (auto denseAttr = attr.dyn_cast()) - return denseAttr.getType().getElementType().isa(); + if (auto denseAttr = llvm::dyn_cast(attr)) + return llvm::isa(denseAttr.getType().getElementType()); return false; } @@ -1489,7 +1489,7 @@ /// Method for supporting type inquiry through isa, cast and dyn_cast. bool DenseIntElementsAttr::classof(Attribute attr) { - if (auto denseAttr = attr.dyn_cast()) + if (auto denseAttr = llvm::dyn_cast(attr)) return denseAttr.getType().getElementType().isIntOrIndex(); return false; } @@ -1525,7 +1525,7 @@ template struct DenseResourceElementsAttrIntUtil { static bool checkElementType(Type eltType) { - IntegerType type = eltType.dyn_cast(); + IntegerType type = llvm::dyn_cast(eltType); if (!type || type.getWidth() != width) return false; return isSigned ? !type.isUnsigned() : !type.isSigned(); @@ -1582,8 +1582,8 @@ "size mismatch between expected element width and blob size"); assert(DenseResourceAttrUtil::checkElementType(type.getElementType()) && "invalid shape element type for provided type `T`"); - return DenseResourceElementsAttr::get(type, blobName, std::move(blob)) - .template cast>(); + return llvm::cast>( + DenseResourceElementsAttr::get(type, blobName, std::move(blob))); } template @@ -1596,7 +1596,7 @@ template bool DenseResourceElementsAttrBase::classof(Attribute attr) { - auto resourceAttr = attr.dyn_cast(); + auto resourceAttr = llvm::dyn_cast(attr); return resourceAttr && DenseResourceAttrUtil::checkElementType( resourceAttr.getElementType()); } @@ -1624,13 +1624,13 @@ /// Get a zero APFloat for the given sparse attribute. APFloat SparseElementsAttr::getZeroAPFloat() const { - auto eltType = getElementType().cast(); + auto eltType = llvm::cast(getElementType()); return APFloat(eltType.getFloatSemantics()); } /// Get a zero APInt for the given sparse attribute. APInt SparseElementsAttr::getZeroAPInt() const { - auto eltType = getElementType().cast(); + auto eltType = llvm::cast(getElementType()); return APInt::getZero(eltType.getWidth()); } @@ -1639,14 +1639,14 @@ auto eltType = getElementType(); // Handle floating point elements. - if (eltType.isa()) + if (llvm::isa(eltType)) return FloatAttr::get(eltType, 0); // Handle complex elements. - if (auto complexTy = eltType.dyn_cast()) { + if (auto complexTy = llvm::dyn_cast(eltType)) { auto eltType = complexTy.getElementType(); Attribute zero; - if (eltType.isa()) + if (llvm::isa(eltType)) zero = FloatAttr::get(eltType, 0); else // must be integer zero = IntegerAttr::get(eltType, 0); @@ -1655,7 +1655,7 @@ } // Handle string type. - if (getValues().isa()) + if (llvm::isa(getValues())) return StringAttr::get("", eltType); // Otherwise, this is an integer. diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp --- a/mlir/lib/IR/BuiltinDialect.cpp +++ b/mlir/lib/IR/BuiltinDialect.cpp @@ -48,15 +48,15 @@ : OpAsmDialectInterface(dialect), blobManager(mgr) {} AliasResult getAlias(Attribute attr, raw_ostream &os) const override { - if (attr.isa()) { + if (llvm::isa(attr)) { os << "map"; return AliasResult::OverridableAlias; } - if (attr.isa()) { + if (llvm::isa(attr)) { os << "set"; return AliasResult::OverridableAlias; } - if (attr.isa()) { + if (llvm::isa(attr)) { os << "loc"; return AliasResult::OverridableAlias; } @@ -64,7 +64,7 @@ } AliasResult getAlias(Type type, raw_ostream &os) const final { - if (auto tupleType = type.dyn_cast()) { + if (auto tupleType = llvm::dyn_cast(type)) { if (tupleType.size() > 16) { os << "tuple"; return AliasResult::OverridableAlias; @@ -145,7 +145,7 @@ // interface. This needs a linear search, but is called only once per data // layout object construction that is used for repeated queries. for (NamedAttribute attr : getOperation()->getAttrs()) - if (auto spec = attr.getValue().dyn_cast()) + if (auto spec = llvm::dyn_cast(attr.getValue())) return spec; return {}; } @@ -168,7 +168,7 @@ StringRef layoutSpecAttrName; DataLayoutSpecInterface layoutSpec; for (const NamedAttribute &na : (*this)->getAttrs()) { - if (auto spec = na.getValue().dyn_cast()) { + if (auto spec = llvm::dyn_cast(na.getValue())) { if (layoutSpec) { InFlightDiagnostic diag = emitOpError() << "expects at most one data layout attribute"; diff --git a/mlir/lib/IR/BuiltinDialectBytecode.cpp b/mlir/lib/IR/BuiltinDialectBytecode.cpp --- a/mlir/lib/IR/BuiltinDialectBytecode.cpp +++ b/mlir/lib/IR/BuiltinDialectBytecode.cpp @@ -32,7 +32,7 @@ static unsigned getIntegerBitWidth(DialectBytecodeReader &reader, Type type) { if (auto intType = dyn_cast(type)) { return intType.getWidth(); - } else if (type.isa()) { + } else if (llvm::isa(type)) { return IndexType::kInternalStorageBitWidth; } reader.emitError() diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -244,10 +244,10 @@ VectorType VectorType::scaleElementBitwidth(unsigned scale) { if (!scale) return VectorType(); - if (auto et = getElementType().dyn_cast()) + if (auto et = llvm::dyn_cast(getElementType())) if (auto scaledEt = et.scaleElementBitwidth(scale)) return VectorType::get(getShape(), scaledEt, getNumScalableDims()); - if (auto et = getElementType().dyn_cast()) + if (auto et = llvm::dyn_cast(getElementType())) if (auto scaledEt = et.scaleElementBitwidth(scale)) return VectorType::get(getShape(), scaledEt, getNumScalableDims()); return VectorType(); @@ -305,8 +305,8 @@ // Note: Non standard/builtin types are allowed to exist within tensor // types. Dialects are expected to verify that tensor types have a valid // element type within that dialect. - return type.isa() || + return llvm::isa(type) || !llvm::isa(type.getDialect()); } @@ -321,7 +321,7 @@ for (int64_t s : shape) if (s < 0 && !ShapedType::isDynamic(s)) return emitError() << "invalid tensor dimension size"; - if (auto v = encoding.dyn_cast_or_null()) + if (auto v = llvm::dyn_cast_or_null(encoding)) if (failed(v.verifyEncoding(shape, elementType, emitError))) return failure(); return checkTensorElementType(emitError, elementType); @@ -426,9 +426,9 @@ if (originalType == candidateReducedType) return SliceVerificationResult::Success; - ShapedType originalShapedType = originalType.cast(); + ShapedType originalShapedType = llvm::cast(originalType); ShapedType candidateReducedShapedType = - candidateReducedType.cast(); + llvm::cast(candidateReducedType); // Rank and size logic is valid for all ShapedTypes. ArrayRef originalShape = originalShapedType.getShape(); @@ -459,7 +459,7 @@ return true; // Supported built-in attributes. - if (memorySpace.isa()) + if (llvm::isa(memorySpace)) return true; // Allow custom dialect attributes. @@ -478,7 +478,7 @@ } Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) { - IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null(); + IntegerAttr intMemorySpace = llvm::dyn_cast_or_null(memorySpace); if (intMemorySpace && intMemorySpace.getValue() == 0) return nullptr; @@ -489,10 +489,10 @@ if (!memorySpace) return 0; - assert(memorySpace.isa() && + assert(llvm::isa(memorySpace) && "Using `getMemorySpaceInteger` with non-Integer attribute"); - return static_cast(memorySpace.cast().getInt()); + return static_cast(llvm::cast(memorySpace).getInt()); } unsigned MemRefType::getMemorySpaceAsInt() const { @@ -786,7 +786,7 @@ SmallVectorImpl &strides, int64_t &offset) { // Happy path: the type uses the strided layout directly. - if (auto strided = t.getLayout().dyn_cast()) { + if (auto strided = llvm::dyn_cast(t.getLayout())) { llvm::append_range(strides, strided.getStrides()); offset = strided.getOffset(); return success(); @@ -834,7 +834,7 @@ /// (i32, tensor, f32, i64) void TupleType::getFlattenedTypes(SmallVectorImpl &types) { for (Type type : getTypes()) { - if (auto nestedTuple = type.dyn_cast()) + if (auto nestedTuple = llvm::dyn_cast(type)) nestedTuple.getFlattenedTypes(types); else types.push_back(type); diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp --- a/mlir/lib/IR/Diagnostics.cpp +++ b/mlir/lib/IR/Diagnostics.cpp @@ -259,7 +259,7 @@ return; auto &os = llvm::errs(); - if (!diag.getLocation().isa()) + if (!llvm::isa(diag.getLocation())) os << diag.getLocation() << ": "; os << "error: "; @@ -448,7 +448,7 @@ if (!fileLoc) { std::string str; llvm::raw_string_ostream strOS(str); - if (!loc.isa()) + if (!llvm::isa(loc)) strOS << loc << ": "; strOS << message; return mgr.PrintMessage(os, SMLoc(), getDiagKind(kind), strOS.str()); @@ -983,7 +983,7 @@ // Print each diagnostic with the format: // ": : " - if (!diag.getLocation().isa()) + if (!llvm::isa(diag.getLocation())) os << diag.getLocation() << ": "; switch (diag.getSeverity()) { case DiagnosticSeverity::Error: diff --git a/mlir/lib/IR/ExtensibleDialect.cpp b/mlir/lib/IR/ExtensibleDialect.cpp --- a/mlir/lib/IR/ExtensibleDialect.cpp +++ b/mlir/lib/IR/ExtensibleDialect.cpp @@ -474,7 +474,7 @@ LogicalResult ExtensibleDialect::printIfDynamicType(Type type, AsmPrinter &printer) { - if (auto dynType = type.dyn_cast()) { + if (auto dynType = llvm::dyn_cast(type)) { dynType.print(printer); return success(); } @@ -496,7 +496,7 @@ LogicalResult ExtensibleDialect::printIfDynamicAttr(Attribute attribute, AsmPrinter &printer) { - if (auto dynAttr = attribute.dyn_cast()) { + if (auto dynAttr = llvm::dyn_cast(attribute)) { dynAttr.print(printer); return success(); } diff --git a/mlir/lib/IR/FunctionImplementation.cpp b/mlir/lib/IR/FunctionImplementation.cpp --- a/mlir/lib/IR/FunctionImplementation.cpp +++ b/mlir/lib/IR/FunctionImplementation.cpp @@ -250,14 +250,14 @@ "Invalid number of attributes."); auto &os = p.getStream(); - bool needsParens = types.size() > 1 || types[0].isa() || - (attrs && !attrs[0].cast().empty()); + bool needsParens = types.size() > 1 || llvm::isa(types[0]) || + (attrs && !llvm::cast(attrs[0]).empty()); if (needsParens) os << '('; llvm::interleaveComma(llvm::seq(0, types.size()), os, [&](size_t i) { p.printType(types[i]); if (attrs) - p.printOptionalAttrDict(attrs[i].cast().getValue()); + p.printOptionalAttrDict(llvm::cast(attrs[i]).getValue()); }); if (needsParens) os << ')'; @@ -278,12 +278,13 @@ if (!isExternal) { ArrayRef attrs; if (argAttrs) - attrs = argAttrs[i].cast().getValue(); + attrs = llvm::cast(argAttrs[i]).getValue(); p.printRegionArgument(body.getArgument(i), attrs); } else { p.printType(argTypes[i]); if (argAttrs) - p.printOptionalAttrDict(argAttrs[i].cast().getValue()); + p.printOptionalAttrDict( + llvm::cast(argAttrs[i]).getValue()); } } diff --git a/mlir/lib/IR/FunctionInterfaces.cpp b/mlir/lib/IR/FunctionInterfaces.cpp --- a/mlir/lib/IR/FunctionInterfaces.cpp +++ b/mlir/lib/IR/FunctionInterfaces.cpp @@ -21,14 +21,14 @@ //===----------------------------------------------------------------------===// static bool isEmptyAttrDict(Attribute attr) { - return attr.cast().empty(); + return llvm::cast(attr).empty(); } DictionaryAttr function_interface_impl::getArgAttrDict(FunctionOpInterface op, unsigned index) { ArrayAttr attrs = op.getArgAttrsAttr(); DictionaryAttr argAttrs = - attrs ? attrs[index].cast() : DictionaryAttr(); + attrs ? llvm::cast(attrs[index]) : DictionaryAttr(); return argAttrs; } @@ -37,7 +37,7 @@ unsigned index) { ArrayAttr attrs = op.getResAttrsAttr(); DictionaryAttr resAttrs = - attrs ? attrs[index].cast() : DictionaryAttr(); + attrs ? llvm::cast(attrs[index]) : DictionaryAttr(); return resAttrs; } @@ -288,7 +288,7 @@ newArgAttrs.reserve(argAttrs.size()); for (unsigned i = 0, e = argIndices.size(); i < e; ++i) if (!argIndices[i]) - newArgAttrs.emplace_back(argAttrs[i].cast()); + newArgAttrs.emplace_back(llvm::cast(argAttrs[i])); setAllArgAttrDicts(op, newArgAttrs); } @@ -309,7 +309,7 @@ newResultAttrs.reserve(resAttrs.size()); for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) if (!resultIndices[i]) - newResultAttrs.emplace_back(resAttrs[i].cast()); + newResultAttrs.emplace_back(llvm::cast(resAttrs[i])); setAllResultAttrDicts(op, newResultAttrs); } diff --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp --- a/mlir/lib/IR/Location.cpp +++ b/mlir/lib/IR/Location.cpp @@ -64,8 +64,8 @@ /// Methods for support type inquiry through isa, cast, and dyn_cast. bool LocationAttr::classof(Attribute attr) { - return attr.isa(); + return llvm::isa(attr); } //===----------------------------------------------------------------------===// @@ -101,7 +101,7 @@ } } // Otherwise, only add known locations to the set. - if (!loc.isa()) + if (!llvm::isa(loc)) decomposedLocs.insert(loc); } locs = decomposedLocs.getArrayRef(); diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -829,11 +829,11 @@ /// If this is a vector type, or a tensor type, return the scalar element type /// that it is built around, otherwise return the type unmodified. static Type getTensorOrVectorElementType(Type type) { - if (auto vec = type.dyn_cast()) + if (auto vec = llvm::dyn_cast(type)) return vec.getElementType(); // Look through tensor> to find the underlying element type. - if (auto tensor = type.dyn_cast()) + if (auto tensor = llvm::dyn_cast(type)) return getTensorOrVectorElementType(tensor.getElementType()); return type; } @@ -867,7 +867,7 @@ LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) { for (auto opType : op->getOperandTypes()) { auto type = getTensorOrVectorElementType(opType); - if (!type.isa()) + if (!llvm::isa(type)) return op->emitOpError("requires a float type"); } return success(); @@ -1102,7 +1102,7 @@ LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) { for (auto resultType : op->getResultTypes()) - if (!getTensorOrVectorElementType(resultType).isa()) + if (!llvm::isa(getTensorOrVectorElementType(resultType))) return op->emitOpError() << "requires a floating point type"; return success(); @@ -1169,7 +1169,7 @@ LogicalResult OpTrait::impl::verifyElementwise(Operation *op) { auto isMappableType = [](Type type) { - return type.isa(); + return llvm::isa(type); }; auto resultMappableTypes = llvm::to_vector<1>( llvm::make_filter_range(op->getResultTypes(), isMappableType)); diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -59,7 +59,7 @@ } if (!dictionarySorted.getPointer()) dictionarySorted.setPointer(DictionaryAttr::getWithSorted(context, attrs)); - return dictionarySorted.getPointer().cast(); + return llvm::cast(dictionarySorted.getPointer()); } /// Add an attribute with the specified name. @@ -405,18 +405,19 @@ OperandRangeRange::OperandRangeRange(OperandRange operands, Attribute operandSegments) : OperandRangeRange(OwnerT(operands.getBase(), operandSegments), 0, - operandSegments.cast().size()) {} + llvm::cast(operandSegments).size()) { +} OperandRange OperandRangeRange::join() const { const OwnerT &owner = getBase(); - ArrayRef sizeData = owner.second.cast(); + ArrayRef sizeData = llvm::cast(owner.second); return OperandRange(owner.first, std::accumulate(sizeData.begin(), sizeData.end(), 0)); } OperandRange OperandRangeRange::dereference(const OwnerT &object, ptrdiff_t index) { - ArrayRef sizeData = object.second.cast(); + ArrayRef sizeData = llvm::cast(object.second); uint32_t startIndex = std::accumulate(sizeData.begin(), sizeData.begin() + index, 0); return OperandRange(object.first + startIndex, *(sizeData.begin() + index)); @@ -508,7 +509,7 @@ // Update any of the provided segment attributes. for (OperandSegment &segment : operandSegments) { - auto attr = segment.second.getValue().cast(); + auto attr = llvm::cast(segment.second.getValue()); SmallVector segments(attr.asArrayRef()); segments[segment.first] += diff; segment.second.setValue( @@ -524,7 +525,8 @@ const MutableOperandRange &operands, NamedAttribute operandSegmentAttr) : MutableOperandRangeRange( OwnerT(operands, operandSegmentAttr), 0, - operandSegmentAttr.getValue().cast().size()) {} + llvm::cast(operandSegmentAttr.getValue()).size()) { +} MutableOperandRange MutableOperandRangeRange::join() const { return getBase().first; @@ -537,7 +539,7 @@ MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object, ptrdiff_t index) { ArrayRef sizeData = - object.second.getValue().cast(); + llvm::cast(object.second.getValue()); uint32_t startIndex = std::accumulate(sizeData.begin(), sizeData.begin() + index, 0); return object.first.slice( @@ -782,8 +784,8 @@ auto sortValues = [](ValueRange values) { SmallVector sortedValues = llvm::to_vector(values); llvm::sort(sortedValues, [](Value a, Value b) { - auto aArg = a.dyn_cast(); - auto bArg = b.dyn_cast(); + auto aArg = llvm::dyn_cast(a); + auto bArg = llvm::dyn_cast(b); // Case 1. Both `a` and `b` are `BlockArgument`s. if (aArg && bArg) { diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -459,7 +459,7 @@ // Verify the visibility attribute. if (Attribute vis = op->getAttr(mlir::SymbolTable::getVisibilityAttrName())) { - StringAttr visStrAttr = vis.dyn_cast(); + StringAttr visStrAttr = llvm::dyn_cast(vis); if (!visStrAttr) return op->emitOpError() << "requires visibility attribute '" << mlir::SymbolTable::getVisibilityAttrName() @@ -669,7 +669,7 @@ // If the references are not pointer equal, check to see if `subRef` is a // prefix of `ref`. - if (ref.isa() || + if (llvm::isa(ref) || ref.getRootReference() != subRef.getRootReference()) return false; @@ -789,7 +789,7 @@ /// Generates a new symbol reference attribute with a new leaf reference. static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr, FlatSymbolRefAttr newLeafAttr) { - if (oldAttr.isa()) + if (llvm::isa(oldAttr)) return newLeafAttr; auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences()); nestedRefs.back() = newLeafAttr; diff --git a/mlir/lib/IR/TypeUtilities.cpp b/mlir/lib/IR/TypeUtilities.cpp --- a/mlir/lib/IR/TypeUtilities.cpp +++ b/mlir/lib/IR/TypeUtilities.cpp @@ -22,7 +22,7 @@ using namespace mlir; Type mlir::getElementTypeOrSelf(Type type) { - if (auto st = type.dyn_cast()) + if (auto st = llvm::dyn_cast(type)) return st.getElementType(); return type; } @@ -32,7 +32,7 @@ } Type mlir::getElementTypeOrSelf(Attribute attr) { - if (auto typedAttr = attr.dyn_cast()) + if (auto typedAttr = llvm::dyn_cast(attr)) return getElementTypeOrSelf(typedAttr.getType()); return {}; } @@ -47,7 +47,7 @@ /// dialect and typeData. bool mlir::isOpaqueTypeWithName(Type type, StringRef dialect, StringRef typeData) { - if (auto opaque = type.dyn_cast()) + if (auto opaque = llvm::dyn_cast(type)) return opaque.getDialectNamespace() == dialect && opaque.getTypeData() == typeData; return false; @@ -76,8 +76,8 @@ /// compatible if at least one is dynamic or both are equal. The element type /// does not matter. LogicalResult mlir::verifyCompatibleShape(Type type1, Type type2) { - auto sType1 = type1.dyn_cast(); - auto sType2 = type2.dyn_cast(); + auto sType1 = llvm::dyn_cast(type1); + auto sType2 = llvm::dyn_cast(type2); // Either both or neither type should be shaped. if (!sType1) @@ -120,7 +120,7 @@ /// dims are equal. The element type does not matter. LogicalResult mlir::verifyCompatibleShapes(TypeRange types) { auto shapedTypes = llvm::to_vector<8>(llvm::map_range( - types, [](auto type) { return type.template dyn_cast(); })); + types, [](auto type) { return llvm::dyn_cast(type); })); // Return failure if some, but not all are not shaped. Return early if none // are shaped also. if (llvm::none_of(shapedTypes, [](auto t) { return t; })) @@ -132,7 +132,7 @@ bool hasScalableVecTypes = false; bool hasNonScalableVecTypes = false; for (Type t : types) { - auto vType = t.dyn_cast(); + auto vType = llvm::dyn_cast(t); if (vType && vType.isScalable()) hasScalableVecTypes = true; else @@ -167,9 +167,9 @@ } Type OperandElementTypeIterator::mapElement(Value value) const { - return value.getType().cast().getElementType(); + return llvm::cast(value.getType()).getElementType(); } Type ResultElementTypeIterator::mapElement(Value value) const { - return value.getType().cast().getElementType(); + return llvm::cast(value.getType()).getElementType(); } diff --git a/mlir/lib/IR/Verifier.cpp b/mlir/lib/IR/Verifier.cpp --- a/mlir/lib/IR/Verifier.cpp +++ b/mlir/lib/IR/Verifier.cpp @@ -302,7 +302,7 @@ } // Block argument case. Block *block1 = op.getBlock(); - Block *block2 = operand.cast().getOwner(); + Block *block2 = llvm::cast(operand).getOwner(); Region *region1 = block1->getParent(); Region *region2 = block2->getParent(); Location loc = UnknownLoc::get(op.getContext()); diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -75,7 +75,7 @@ if (parser.parseRSquare() || parser.parseGreater()) return Attribute(); return parser.getChecked( - parser.getContext(), type.cast(), elements); + parser.getContext(), llvm::cast(type), elements); } void TestI64ElementsAttr::print(AsmPrinter &printer) const { diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -287,18 +287,18 @@ for (DataLayoutEntryInterface entry : params) { // This is for testing purposes only, so assert well-formedness. assert(entry.isTypeEntry() && "unexpected identifier entry"); - assert(entry.getKey().get().isa() && + assert(llvm::isa(entry.getKey().get()) && "wrong type passed in"); - auto array = entry.getValue().dyn_cast(); + auto array = llvm::dyn_cast(entry.getValue()); assert(array && array.getValue().size() == 2 && "expected array of two elements"); - auto kind = array.getValue().front().dyn_cast(); + auto kind = llvm::dyn_cast(array.getValue().front()); (void)kind; assert(kind && (kind.getValue() == "size" || kind.getValue() == "alignment" || kind.getValue() == "preferred") && "unexpected kind"); - assert(array.getValue().back().isa()); + assert(llvm::isa(array.getValue().back())); } return success(); } @@ -306,10 +306,11 @@ unsigned TestTypeWithLayoutType::extractKind(DataLayoutEntryListRef params, StringRef expectedKind) const { for (DataLayoutEntryInterface entry : params) { - ArrayRef pair = entry.getValue().cast().getValue(); - StringRef kind = pair.front().cast().getValue(); + ArrayRef pair = + llvm::cast(entry.getValue()).getValue(); + StringRef kind = llvm::cast(pair.front()).getValue(); if (kind == expectedKind) - return pair.back().cast().getValue().getZExtValue(); + return llvm::cast(pair.back()).getValue().getZExtValue(); } return 1; } @@ -466,7 +467,7 @@ if (succeeded(printIfDynamicType(type, printer))) return; - auto rec = type.cast(); + auto rec = llvm::cast(type); printer << "test_rec<" << rec.getName(); if (!stack.contains(rec)) { printer << ", "; diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -109,10 +109,10 @@ mlir::test::TestProduceSelfHandleOrForwardOperandOp::apply( transform::TransformResults &results, transform::TransformState &state) { if (getOperation()->getNumOperands() != 0) { - results.set(getResult().cast(), + results.set(llvm::cast(getResult()), getOperation()->getOperand(0).getDefiningOp()); } else { - results.set(getResult().cast(), getOperation()); + results.set(llvm::cast(getResult()), getOperation()); } return DiagnosedSilenceableFailure::success(); } @@ -127,7 +127,7 @@ DiagnosedSilenceableFailure mlir::test::TestProduceValueHandleToSelfOperand::apply( transform::TransformResults &results, transform::TransformState &state) { - results.setValues(getOut().cast(), getIn()); + results.setValues(llvm::cast(getOut()), getIn()); return DiagnosedSilenceableFailure::success(); } @@ -249,13 +249,13 @@ for (Value value : values) { std::string note; llvm::raw_string_ostream os(note); - if (auto arg = value.dyn_cast()) { + if (auto arg = llvm::dyn_cast(value)) { os << "a block argument #" << arg.getArgNumber() << " in block #" << std::distance(arg.getOwner()->getParent()->begin(), arg.getOwner()->getIterator()) << " in region #" << arg.getOwner()->getParent()->getRegionNumber(); } else { - os << "an op result #" << value.cast().getResultNumber(); + os << "an op result #" << llvm::cast(value).getResultNumber(); } InFlightDiagnostic diag = ::emitRemark(value.getLoc()) << getMessage(); diag.attachNote() << "value handle points to " << os.str(); @@ -317,7 +317,7 @@ getOperation()))) return DiagnosedSilenceableFailure::definiteFailure(); if (getNumResults() > 0) - results.set(getResult(0).cast(), getOperation()); + results.set(llvm::cast(getResult(0)), getOperation()); return DiagnosedSilenceableFailure::success(); } @@ -339,7 +339,7 @@ transform::TransformState &state) { ArrayRef payloadOps = state.getPayloadOps(getTarget()); auto reversedOps = llvm::to_vector(llvm::reverse(payloadOps)); - results.set(getResult().cast(), reversedOps); + results.set(llvm::cast(getResult()), reversedOps); return DiagnosedSilenceableFailure::success(); } @@ -443,7 +443,8 @@ DiagnosedSilenceableFailure mlir::test::TestCopyPayloadOp::apply(transform::TransformResults &results, transform::TransformState &state) { - results.set(getCopy().cast(), state.getPayloadOps(getHandle())); + results.set(llvm::cast(getCopy()), + state.getPayloadOps(getHandle())); return DiagnosedSilenceableFailure::success(); } @@ -472,7 +473,7 @@ DiagnosedSilenceableFailure mlir::transform::TestDialectParamType::checkPayload( Location loc, ArrayRef payload) const { for (Attribute attr : payload) { - auto integerAttr = attr.dyn_cast(); + auto integerAttr = llvm::dyn_cast(attr); if (integerAttr && integerAttr.getType().isSignlessInteger(32)) continue; return emitSilenceableError(loc) @@ -534,7 +535,7 @@ if (Value param = getParam()) { values = llvm::to_vector( llvm::map_range(state.getParams(param), [](Attribute attr) -> uint32_t { - return attr.cast().getValue().getLimitedValue( + return llvm::cast(attr).getValue().getLimitedValue( UINT32_MAX); })); } @@ -544,7 +545,7 @@ llvm::map_range(values, [this, &builder](uint32_t value) -> Attribute { return builder.getI32IntegerAttr(value + getAddendum()); })); - results.setParams(getResult().cast(), result); + results.setParams(llvm::cast(getResult()), result); return DiagnosedSilenceableFailure::success(); } @@ -562,7 +563,7 @@ }); return builder.getI32IntegerAttr(count); })); - results.setParams(getResult().cast(), result); + results.setParams(llvm::cast(getResult()), result); return DiagnosedSilenceableFailure::success(); } @@ -570,12 +571,12 @@ mlir::test::TestProduceIntegerParamWithTypeOp::apply( transform::TransformResults &results, transform::TransformState &state) { Attribute zero = IntegerAttr::get(getType(), 0); - results.setParams(getResult().cast(), zero); + results.setParams(llvm::cast(getResult()), zero); return DiagnosedSilenceableFailure::success(); } LogicalResult mlir::test::TestProduceIntegerParamWithTypeOp::verify() { - if (!getType().isa()) { + if (!llvm::isa(getType())) { return emitOpError() << "expects an integer type"; } return success(); @@ -618,7 +619,7 @@ DiagnosedSilenceableFailure mlir::test::TestProduceNullPayloadOp::apply( transform::TransformResults &results, transform::TransformState &state) { SmallVector null({nullptr}); - results.set(getOut().cast(), null); + results.set(llvm::cast(getOut()), null); return DiagnosedSilenceableFailure::success(); } @@ -630,7 +631,7 @@ DiagnosedSilenceableFailure mlir::test::TestProduceNullParamOp::apply(transform::TransformResults &results, transform::TransformState &state) { - results.setParams(getOut().cast(), Attribute()); + results.setParams(llvm::cast(getOut()), Attribute()); return DiagnosedSilenceableFailure::success(); } @@ -642,7 +643,7 @@ DiagnosedSilenceableFailure mlir::test::TestProduceNullValueOp::apply(transform::TransformResults &results, transform::TransformState &state) { - results.setValues(getOut().cast(), Value()); + results.setValues(llvm::cast(getOut()), Value()); return DiagnosedSilenceableFailure::success(); } @@ -662,7 +663,7 @@ DiagnosedSilenceableFailure mlir::test::TestRequiredMemoryEffectsOp::apply( transform::TransformResults &results, transform::TransformState &state) { - results.set(getOut().cast(), state.getPayloadOps(getIn())); + results.set(llvm::cast(getOut()), state.getPayloadOps(getIn())); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/test/python/lib/PythonTestCAPI.cpp b/mlir/test/python/lib/PythonTestCAPI.cpp --- a/mlir/test/python/lib/PythonTestCAPI.cpp +++ b/mlir/test/python/lib/PythonTestCAPI.cpp @@ -16,7 +16,7 @@ python_test::PythonTestDialect) bool mlirAttributeIsAPythonTestTestAttribute(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirPythonTestTestAttributeGet(MlirContext context) { @@ -24,7 +24,7 @@ } bool mlirTypeIsAPythonTestTestType(MlirType type) { - return unwrap(type).isa(); + return llvm::isa(unwrap(type)); } MlirType mlirPythonTestTestTypeGet(MlirContext context) { diff --git a/mlir/unittests/TableGen/EnumsGenTest.cpp b/mlir/unittests/TableGen/EnumsGenTest.cpp --- a/mlir/unittests/TableGen/EnumsGenTest.cpp +++ b/mlir/unittests/TableGen/EnumsGenTest.cpp @@ -175,7 +175,7 @@ mlir::Type intType = mlir::IntegerType::get(&ctx, 32); mlir::Attribute intAttr = mlir::IntegerAttr::get(intType, 5); - EXPECT_TRUE(intAttr.isa()); + EXPECT_TRUE(llvm::isa(intAttr)); EXPECT_EQ(intAttr, enumAttr); } @@ -186,10 +186,10 @@ mlir::Attribute intAttr = mlir::IntegerAttr::get( intType, static_cast(BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3)); - EXPECT_TRUE(intAttr.isa()); - EXPECT_TRUE(intAttr.isa()); + EXPECT_TRUE(llvm::isa(intAttr)); + EXPECT_TRUE(llvm::isa(intAttr)); intAttr = mlir::IntegerAttr::get( intType, static_cast(BitEnumWithGroup::Bits0To3) | (1u << 6)); - EXPECT_FALSE(intAttr.isa()); + EXPECT_FALSE(llvm::isa(intAttr)); }