diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h --- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h +++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h @@ -38,9 +38,8 @@ // Get the name of the arith fastmath attribute. llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName(); // Remove the source fastmath attribute. - auto arithFMFAttr = - convertedAttr.erase(arithFMFAttrName) - .template dyn_cast_or_null(); + auto arithFMFAttr = llvm::dyn_cast_or_null( + convertedAttr.erase(arithFMFAttrName)); if (arithFMFAttr) { llvm::StringRef targetAttrName = TargetOp::getFastmathAttrName(); convertedAttr.set(targetAttrName, diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -222,7 +222,7 @@ SmallVector indices; indices.reserve(attrs.size()); for (Attribute attr : attrs) - indices.push_back(attr.cast().getInt()); + indices.push_back(llvm::cast(attr).getInt()); return indices; } diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -403,7 +403,7 @@ if (Operation *op = val.getDefiningOp()) { setInsertionPointAfter(op); } else { - auto blockArg = val.cast(); + auto blockArg = llvm::cast(val); setInsertionPointToStart(blockArg.getOwner()); } } diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -389,7 +389,7 @@ !std::is_same::value, T> getSplatValue() const { - return getSplatValue().template cast(); + return llvm::cast(getSplatValue()); } /// Try to get an iterator of the given type to the start of the held element @@ -510,7 +510,7 @@ T>::mapped_iterator_base; /// Map the element to the iterator result type. - T mapElement(Attribute attr) const { return attr.cast(); } + T mapElement(Attribute attr) const { return llvm::cast(attr); } }; template > FailureOr>> @@ -684,7 +684,7 @@ /// Method for support type inquiry through isa, cast and dyn_cast. static bool classof(Attribute attr) { - auto denseAttr = attr.dyn_cast(); + auto denseAttr = llvm::dyn_cast(attr); return denseAttr && denseAttr.isSplat(); } }; @@ -887,7 +887,7 @@ /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool classof(Attribute attr) { - SymbolRefAttr refAttr = attr.dyn_cast(); + SymbolRefAttr refAttr = llvm::dyn_cast(attr); return refAttr && refAttr.getNestedReferences().empty(); } @@ -912,14 +912,13 @@ /// simply wraps the DenseElementsAttr::get calls. template static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg) { - return DenseElementsAttr::get(type, llvm::ArrayRef(arg)) - .template cast(); + return llvm::cast( + DenseElementsAttr::get(type, llvm::ArrayRef(arg))); } template static DenseFPElementsAttr get(const ShapedType &type, const std::initializer_list &list) { - return DenseElementsAttr::get(type, list) - .template cast(); + return llvm::cast(DenseElementsAttr::get(type, list)); } /// Generates a new DenseElementsAttr by mapping each value attribute, and @@ -954,14 +953,13 @@ /// simply wraps the DenseElementsAttr::get calls. template static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg) { - return DenseElementsAttr::get(type, llvm::ArrayRef(arg)) - .template cast(); + return llvm::cast( + DenseElementsAttr::get(type, llvm::ArrayRef(arg))); } template static DenseIntElementsAttr get(const ShapedType &type, const std::initializer_list &list) { - return DenseElementsAttr::get(type, list) - .template cast(); + return llvm::cast(DenseElementsAttr::get(type, list)); } /// Generates a new DenseElementsAttr by mapping each value attribute, and diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -367,20 +367,21 @@ //===----------------------------------------------------------------------===// inline bool BaseMemRefType::classof(Type type) { - return type.isa(); + return llvm::isa(type); } inline bool BaseMemRefType::isValidElementType(Type type) { return type.isIntOrIndexOrFloat() || - type.isa() || - type.isa(); + llvm::isa( + type) || + llvm::isa(type); } inline bool FloatType::classof(Type type) { - return type - .isa(); + return llvm::isa(type); } inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) { @@ -428,7 +429,7 @@ } inline bool TensorType::classof(Type type) { - return type.isa(); + return llvm::isa(type); } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/FunctionInterfaces.h b/mlir/include/mlir/IR/FunctionInterfaces.h --- a/mlir/include/mlir/IR/FunctionInterfaces.h +++ b/mlir/include/mlir/IR/FunctionInterfaces.h @@ -178,7 +178,7 @@ } for (unsigned i = 0; i != numArgs; ++i) { DictionaryAttr argAttrs = - allArgAttrs[i].dyn_cast_or_null(); + llvm::dyn_cast_or_null(allArgAttrs[i]); if (!argAttrs) { return op.emitOpError() << "expects argument attribute dictionary " "to be a DictionaryAttr, but got `" @@ -209,7 +209,7 @@ } for (unsigned i = 0; i != numResults; ++i) { DictionaryAttr resultAttrs = - allResultAttrs[i].dyn_cast_or_null(); + llvm::dyn_cast_or_null(allResultAttrs[i]); if (!resultAttrs) { return op.emitOpError() << "expects result attribute dictionary " "to be a DictionaryAttr, but got `" diff --git a/mlir/include/mlir/IR/Location.h b/mlir/include/mlir/IR/Location.h --- a/mlir/include/mlir/IR/Location.h +++ b/mlir/include/mlir/IR/Location.h @@ -148,12 +148,12 @@ /// Return the metadata associated with this fused location. MetadataT getMetadata() const { - return FusedLoc::getMetadata().template cast(); + return llvm::cast(FusedLoc::getMetadata()); } /// Support llvm style casting. static bool classof(Attribute attr) { - auto fusedLoc = attr.dyn_cast(); + auto fusedLoc = llvm::dyn_cast(attr); return fusedLoc && fusedLoc.getMetadata().isa_and_nonnull(); } }; diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h --- a/mlir/include/mlir/IR/Matchers.h +++ b/mlir/include/mlir/IR/Matchers.h @@ -39,7 +39,7 @@ attr_value_binder(ValueType *bv) : bind_value(bv) {} bool match(const Attribute &attr) { - if (auto intAttr = attr.dyn_cast()) { + if (auto intAttr = llvm::dyn_cast(attr)) { *bind_value = intAttr.getValue(); return true; } @@ -90,7 +90,7 @@ (void)result; assert(succeeded(result) && "expected ConstantLike op to be foldable"); - if (auto attr = foldedOp.front().get().dyn_cast()) { + if (auto attr = llvm::dyn_cast(foldedOp.front().get())) { if (bind_value) *bind_value = attr; return true; @@ -136,10 +136,10 @@ return false; auto type = op->getResult(0).getType(); - if (type.isa()) + if (llvm::isa(type)) return attr_value_binder(bind_value).match(attr); - if (type.isa()) { - if (auto splatAttr = attr.dyn_cast()) { + if (llvm::isa(type)) { + if (auto splatAttr = llvm::dyn_cast(attr)) { return attr_value_binder(bind_value) .match(splatAttr.getSplatValue()); } @@ -173,10 +173,10 @@ return false; auto type = op->getResult(0).getType(); - if (type.isa()) + if (llvm::isa(type)) return attr_value_binder(bind_value).match(attr); - if (type.isa()) { - if (auto splatAttr = attr.dyn_cast()) { + if (llvm::isa(type)) { + if (auto splatAttr = llvm::dyn_cast(attr)) { return attr_value_binder(bind_value) .match(splatAttr.getSplatValue()); } diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -204,7 +204,7 @@ auto &os = getStream() << " -> "; bool wrapped = !llvm::hasSingleElement(types) || - (*types.begin()).template isa(); + llvm::isa((*types.begin())); if (wrapped) os << '('; llvm::interleaveComma(types, *this); @@ -865,7 +865,7 @@ return failure(); // Check for the right kind of attribute. - if (!(result = attr.dyn_cast())) + if (!(result = llvm::dyn_cast(attr))) return emitError(loc, "invalid kind of attribute specified"); return success(); @@ -899,7 +899,7 @@ return failure(); // Check for the right kind of attribute. - result = attr.dyn_cast(); + result = llvm::dyn_cast(attr); if (!result) return emitError(loc, "invalid kind of attribute specified"); @@ -936,7 +936,7 @@ return failure(); // Check for the right kind of attribute. - result = attr.dyn_cast(); + result = llvm::dyn_cast(attr); if (!result) return emitError(loc, "invalid kind of attribute specified"); @@ -970,7 +970,7 @@ return failure(); // Check for the right kind of attribute. - result = attr.dyn_cast(); + result = llvm::dyn_cast(attr); if (!result) return emitError(loc, "invalid kind of attribute specified"); return success(); @@ -1126,7 +1126,7 @@ return failure(); // Check for the right kind of type. - result = type.dyn_cast(); + result = llvm::dyn_cast(type); if (!result) return emitError(loc, "invalid kind of type specified"); @@ -1158,7 +1158,7 @@ return failure(); // Check for the right kind of Type. - result = type.dyn_cast(); + result = llvm::dyn_cast(type); if (!result) return emitError(loc, "invalid kind of Type specified"); return success(); @@ -1198,7 +1198,7 @@ return failure(); // Check for the right kind of type. - result = type.dyn_cast(); + result = llvm::dyn_cast(type); if (!result) return emitError(loc, "invalid kind of type specified"); diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -509,11 +509,11 @@ template AttrClass getAttrOfType(StringAttr name) { - return getAttr(name).dyn_cast_or_null(); + return llvm::dyn_cast_or_null(getAttr(name)); } template AttrClass getAttrOfType(StringRef name) { - return getAttr(name).dyn_cast_or_null(); + return llvm::dyn_cast_or_null(getAttr(name)); } /// Return true if the operation has an attribute with the provided name, diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -433,7 +433,7 @@ static bool classof(Value value) { return llvm::isa(value.getType()); } /// Return the known Type - Ty getType() { return Value::getType().template cast(); } + Ty getType() { return llvm::cast(Value::getType()); } void setType(Ty ty) { Value::setType(ty); } }; diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -395,8 +395,8 @@ wrapTypeAttributeConversion(FnT &&callback) { return [callback = std::forward(callback)]( Type type, Attribute attr) -> AttributeConversionResult { - if (T derivedType = type.dyn_cast()) { - if (A derivedAttr = attr.dyn_cast_or_null()) + if (T derivedType = llvm::dyn_cast(type)) { + if (A derivedAttr = llvm::dyn_cast_or_null(attr)) return callback(derivedType, derivedAttr); } return AttributeConversionResult::na(); diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -25,7 +25,7 @@ //===----------------------------------------------------------------------===// bool mlirAttributeIsALocation(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } //===----------------------------------------------------------------------===// @@ -33,7 +33,7 @@ //===----------------------------------------------------------------------===// bool mlirAttributeIsAAffineMap(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map) { @@ -41,7 +41,7 @@ } MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getValue()); + return wrap(llvm::cast(unwrap(attr)).getValue()); } //===----------------------------------------------------------------------===// @@ -49,7 +49,7 @@ //===----------------------------------------------------------------------===// bool mlirAttributeIsAArray(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements, @@ -61,11 +61,11 @@ } intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr) { - return static_cast(unwrap(attr).cast().size()); + return static_cast(llvm::cast(unwrap(attr)).size()); } MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos) { - return wrap(unwrap(attr).cast().getValue()[pos]); + return wrap(llvm::cast(unwrap(attr)).getValue()[pos]); } //===----------------------------------------------------------------------===// @@ -73,7 +73,7 @@ //===----------------------------------------------------------------------===// bool mlirAttributeIsADictionary(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements, @@ -87,19 +87,19 @@ } intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr) { - return static_cast(unwrap(attr).cast().size()); + return static_cast(llvm::cast(unwrap(attr)).size()); } MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr, intptr_t pos) { NamedAttribute attribute = - unwrap(attr).cast().getValue()[pos]; + llvm::cast(unwrap(attr)).getValue()[pos]; return {wrap(attribute.getName()), wrap(attribute.getValue())}; } MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr, MlirStringRef name) { - return wrap(unwrap(attr).cast().get(unwrap(name))); + return wrap(llvm::cast(unwrap(attr)).get(unwrap(name))); } //===----------------------------------------------------------------------===// @@ -107,7 +107,7 @@ //===----------------------------------------------------------------------===// bool mlirAttributeIsAFloat(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type, @@ -121,7 +121,7 @@ } double mlirFloatAttrGetValueDouble(MlirAttribute attr) { - return unwrap(attr).cast().getValueAsDouble(); + return llvm::cast(unwrap(attr)).getValueAsDouble(); } //===----------------------------------------------------------------------===// @@ -129,7 +129,7 @@ //===----------------------------------------------------------------------===// bool mlirAttributeIsAInteger(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value) { @@ -137,15 +137,15 @@ } int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr) { - return unwrap(attr).cast().getInt(); + return llvm::cast(unwrap(attr)).getInt(); } int64_t mlirIntegerAttrGetValueSInt(MlirAttribute attr) { - return unwrap(attr).cast().getSInt(); + return llvm::cast(unwrap(attr)).getSInt(); } uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr) { - return unwrap(attr).cast().getUInt(); + return llvm::cast(unwrap(attr)).getUInt(); } //===----------------------------------------------------------------------===// @@ -153,7 +153,7 @@ //===----------------------------------------------------------------------===// bool mlirAttributeIsABool(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) { @@ -161,7 +161,7 @@ } bool mlirBoolAttrGetValue(MlirAttribute attr) { - return unwrap(attr).cast().getValue(); + return llvm::cast(unwrap(attr)).getValue(); } //===----------------------------------------------------------------------===// @@ -169,7 +169,7 @@ //===----------------------------------------------------------------------===// bool mlirAttributeIsAIntegerSet(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } //===----------------------------------------------------------------------===// @@ -177,7 +177,7 @@ //===----------------------------------------------------------------------===// bool mlirAttributeIsAOpaque(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace, @@ -189,11 +189,12 @@ } MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getDialectNamespace().strref()); + return wrap( + llvm::cast(unwrap(attr)).getDialectNamespace().strref()); } MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getAttrData()); + return wrap(llvm::cast(unwrap(attr)).getAttrData()); } //===----------------------------------------------------------------------===// @@ -201,7 +202,7 @@ //===----------------------------------------------------------------------===// bool mlirAttributeIsAString(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) { @@ -213,7 +214,7 @@ } MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getValue()); + return wrap(llvm::cast(unwrap(attr)).getValue()); } //===----------------------------------------------------------------------===// @@ -221,7 +222,7 @@ //===----------------------------------------------------------------------===// bool mlirAttributeIsASymbolRef(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol, @@ -230,27 +231,30 @@ SmallVector refs; refs.reserve(numReferences); for (intptr_t i = 0; i < numReferences; ++i) - refs.push_back(unwrap(references[i]).cast()); + refs.push_back(llvm::cast(unwrap(references[i]))); auto symbolAttr = StringAttr::get(unwrap(ctx), unwrap(symbol)); return wrap(SymbolRefAttr::get(symbolAttr, refs)); } MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getRootReference().getValue()); + return wrap( + llvm::cast(unwrap(attr)).getRootReference().getValue()); } MlirStringRef mlirSymbolRefAttrGetLeafReference(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getLeafReference().getValue()); + return wrap( + llvm::cast(unwrap(attr)).getLeafReference().getValue()); } intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr) { return static_cast( - unwrap(attr).cast().getNestedReferences().size()); + llvm::cast(unwrap(attr)).getNestedReferences().size()); } MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr, intptr_t pos) { - return wrap(unwrap(attr).cast().getNestedReferences()[pos]); + return wrap( + llvm::cast(unwrap(attr)).getNestedReferences()[pos]); } //===----------------------------------------------------------------------===// @@ -258,7 +262,7 @@ //===----------------------------------------------------------------------===// bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol) { @@ -266,7 +270,7 @@ } MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getValue()); + return wrap(llvm::cast(unwrap(attr)).getValue()); } //===----------------------------------------------------------------------===// @@ -274,7 +278,7 @@ //===----------------------------------------------------------------------===// bool mlirAttributeIsAType(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirTypeAttrGet(MlirType type) { @@ -282,7 +286,7 @@ } MlirType mlirTypeAttrGetValue(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getValue()); + return wrap(llvm::cast(unwrap(attr)).getValue()); } //===----------------------------------------------------------------------===// @@ -290,7 +294,7 @@ //===----------------------------------------------------------------------===// bool mlirAttributeIsAUnit(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirUnitAttrGet(MlirContext ctx) { @@ -302,24 +306,23 @@ //===----------------------------------------------------------------------===// bool mlirAttributeIsAElements(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank, uint64_t *idxs) { - return wrap(unwrap(attr) - .cast() + return wrap(llvm::cast(unwrap(attr)) .getValues()[llvm::ArrayRef(idxs, rank)]); } bool mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank, uint64_t *idxs) { - return unwrap(attr).cast().isValidIndex( - llvm::ArrayRef(idxs, rank)); + return llvm::cast(unwrap(attr)) + .isValidIndex(llvm::ArrayRef(idxs, rank)); } int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) { - return unwrap(attr).cast().getNumElements(); + return llvm::cast(unwrap(attr)).getNumElements(); } //===----------------------------------------------------------------------===// @@ -330,25 +333,25 @@ // IsA support. bool mlirAttributeIsADenseBoolArray(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } bool mlirAttributeIsADenseI8Array(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } bool mlirAttributeIsADenseI16Array(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } bool mlirAttributeIsADenseI32Array(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } bool mlirAttributeIsADenseI64Array(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } bool mlirAttributeIsADenseF32Array(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } bool mlirAttributeIsADenseF64Array(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } //===----------------------------------------------------------------------===// @@ -394,32 +397,32 @@ // Accessors. intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr) { - return unwrap(attr).cast().size(); + return llvm::cast(unwrap(attr)).size(); } //===----------------------------------------------------------------------===// // Indexed accessors. bool mlirDenseBoolArrayGetElement(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast()[pos]; + return llvm::cast(unwrap(attr))[pos]; } int8_t mlirDenseI8ArrayGetElement(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast()[pos]; + return llvm::cast(unwrap(attr))[pos]; } int16_t mlirDenseI16ArrayGetElement(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast()[pos]; + return llvm::cast(unwrap(attr))[pos]; } int32_t mlirDenseI32ArrayGetElement(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast()[pos]; + return llvm::cast(unwrap(attr))[pos]; } int64_t mlirDenseI64ArrayGetElement(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast()[pos]; + return llvm::cast(unwrap(attr))[pos]; } float mlirDenseF32ArrayGetElement(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast()[pos]; + return llvm::cast(unwrap(attr))[pos]; } double mlirDenseF64ArrayGetElement(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast()[pos]; + return llvm::cast(unwrap(attr))[pos]; } //===----------------------------------------------------------------------===// @@ -430,13 +433,13 @@ // IsA support. bool mlirAttributeIsADenseElements(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } bool mlirAttributeIsADenseIntElements(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } bool mlirAttributeIsADenseFPElements(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } //===----------------------------------------------------------------------===// @@ -447,14 +450,14 @@ MlirAttribute const *elements) { SmallVector attributes; return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), + DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), unwrapList(numElements, elements, attributes))); } MlirAttribute mlirDenseElementsAttrRawBufferGet(MlirType shapedType, size_t rawBufferSize, const void *rawBuffer) { - auto shapedTypeCpp = unwrap(shapedType).cast(); + auto shapedTypeCpp = llvm::cast(unwrap(shapedType)); ArrayRef rawBufferCpp(static_cast(rawBuffer), rawBufferSize); bool isSplat = false; @@ -466,61 +469,61 @@ MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType, MlirAttribute element) { - return wrap(DenseElementsAttr::get(unwrap(shapedType).cast(), + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), unwrap(element))); } MlirAttribute mlirDenseElementsAttrBoolSplatGet(MlirType shapedType, bool element) { - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), element)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); } MlirAttribute mlirDenseElementsAttrUInt8SplatGet(MlirType shapedType, uint8_t element) { - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), element)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); } MlirAttribute mlirDenseElementsAttrInt8SplatGet(MlirType shapedType, int8_t element) { - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), element)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); } MlirAttribute mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType, uint32_t element) { - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), element)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); } MlirAttribute mlirDenseElementsAttrInt32SplatGet(MlirType shapedType, int32_t element) { - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), element)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); } MlirAttribute mlirDenseElementsAttrUInt64SplatGet(MlirType shapedType, uint64_t element) { - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), element)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); } MlirAttribute mlirDenseElementsAttrInt64SplatGet(MlirType shapedType, int64_t element) { - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), element)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); } MlirAttribute mlirDenseElementsAttrFloatSplatGet(MlirType shapedType, float element) { - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), element)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); } MlirAttribute mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType, double element) { - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), element)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); } MlirAttribute mlirDenseElementsAttrBoolGet(MlirType shapedType, intptr_t numElements, const int *elements) { SmallVector values(elements, elements + numElements); - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), values)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + values)); } /// Creates a dense attribute with elements of the type deduced by templates. @@ -528,7 +531,7 @@ static MlirAttribute getDenseAttribute(MlirType shapedType, intptr_t numElements, const T *elements) { - return wrap(DenseElementsAttr::get(unwrap(shapedType).cast(), + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), llvm::ArrayRef(elements, numElements))); } @@ -605,99 +608,99 @@ for (intptr_t i = 0; i < numElements; ++i) values.push_back(unwrap(strs[i])); - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), values)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + values)); } MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr, MlirType shapedType) { - return wrap(unwrap(attr).cast().reshape( - unwrap(shapedType).cast())); + return wrap(llvm::cast(unwrap(attr)) + .reshape(llvm::cast(unwrap(shapedType)))); } //===----------------------------------------------------------------------===// // Splat accessors. bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) { - return unwrap(attr).cast().isSplat(); + return llvm::cast(unwrap(attr)).isSplat(); } MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) { return wrap( - unwrap(attr).cast().getSplatValue()); + llvm::cast(unwrap(attr)).getSplatValue()); } int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) { - return unwrap(attr).cast().getSplatValue(); + return llvm::cast(unwrap(attr)).getSplatValue(); } int8_t mlirDenseElementsAttrGetInt8SplatValue(MlirAttribute attr) { - return unwrap(attr).cast().getSplatValue(); + return llvm::cast(unwrap(attr)).getSplatValue(); } uint8_t mlirDenseElementsAttrGetUInt8SplatValue(MlirAttribute attr) { - return unwrap(attr).cast().getSplatValue(); + return llvm::cast(unwrap(attr)).getSplatValue(); } int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr) { - return unwrap(attr).cast().getSplatValue(); + return llvm::cast(unwrap(attr)).getSplatValue(); } uint32_t mlirDenseElementsAttrGetUInt32SplatValue(MlirAttribute attr) { - return unwrap(attr).cast().getSplatValue(); + return llvm::cast(unwrap(attr)).getSplatValue(); } int64_t mlirDenseElementsAttrGetInt64SplatValue(MlirAttribute attr) { - return unwrap(attr).cast().getSplatValue(); + return llvm::cast(unwrap(attr)).getSplatValue(); } uint64_t mlirDenseElementsAttrGetUInt64SplatValue(MlirAttribute attr) { - return unwrap(attr).cast().getSplatValue(); + return llvm::cast(unwrap(attr)).getSplatValue(); } float mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr) { - return unwrap(attr).cast().getSplatValue(); + return llvm::cast(unwrap(attr)).getSplatValue(); } double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr) { - return unwrap(attr).cast().getSplatValue(); + return llvm::cast(unwrap(attr)).getSplatValue(); } MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) { return wrap( - unwrap(attr).cast().getSplatValue()); + llvm::cast(unwrap(attr)).getSplatValue()); } //===----------------------------------------------------------------------===// // Indexed accessors. bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getValues()[pos]; + return llvm::cast(unwrap(attr)).getValues()[pos]; } int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getValues()[pos]; + return llvm::cast(unwrap(attr)).getValues()[pos]; } uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getValues()[pos]; + return llvm::cast(unwrap(attr)).getValues()[pos]; } int16_t mlirDenseElementsAttrGetInt16Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getValues()[pos]; + return llvm::cast(unwrap(attr)).getValues()[pos]; } uint16_t mlirDenseElementsAttrGetUInt16Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getValues()[pos]; + return llvm::cast(unwrap(attr)).getValues()[pos]; } int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getValues()[pos]; + return llvm::cast(unwrap(attr)).getValues()[pos]; } uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getValues()[pos]; + return llvm::cast(unwrap(attr)).getValues()[pos]; } int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getValues()[pos]; + return llvm::cast(unwrap(attr)).getValues()[pos]; } uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getValues()[pos]; + return llvm::cast(unwrap(attr)).getValues()[pos]; } float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getValues()[pos]; + return llvm::cast(unwrap(attr)).getValues()[pos]; } double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getValues()[pos]; + return llvm::cast(unwrap(attr)).getValues()[pos]; } MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr, intptr_t pos) { return wrap( - unwrap(attr).cast().getValues()[pos]); + llvm::cast(unwrap(attr)).getValues()[pos]); } //===----------------------------------------------------------------------===// @@ -705,7 +708,7 @@ const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) { return static_cast( - unwrap(attr).cast().getRawData().data()); + llvm::cast(unwrap(attr)).getRawData().data()); } //===----------------------------------------------------------------------===// @@ -715,7 +718,7 @@ template static MlirAttribute getDenseResource(MlirType shapedType, MlirStringRef name, intptr_t numElements, const T *elements) { - return wrap(U::get(unwrap(shapedType).cast(), unwrap(name), + return wrap(U::get(llvm::cast(unwrap(shapedType)), unwrap(name), UnmanagedAsmResourceBlob::allocateInferAlign( llvm::ArrayRef(elements, numElements)))); } @@ -797,7 +800,7 @@ template static T getDenseResourceVal(MlirAttribute attr, intptr_t pos) { - return (*unwrap(attr).cast().tryGetAsArrayRef())[pos]; + return (*llvm::cast(unwrap(attr)).tryGetAsArrayRef())[pos]; } MLIR_CAPI_EXPORTED bool @@ -853,24 +856,24 @@ //===----------------------------------------------------------------------===// bool mlirAttributeIsASparseElements(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirSparseElementsAttribute(MlirType shapedType, MlirAttribute denseIndices, MlirAttribute denseValues) { - return wrap( - SparseElementsAttr::get(unwrap(shapedType).cast(), - unwrap(denseIndices).cast(), - unwrap(denseValues).cast())); + return wrap(SparseElementsAttr::get( + llvm::cast(unwrap(shapedType)), + llvm::cast(unwrap(denseIndices)), + llvm::cast(unwrap(denseValues)))); } MlirAttribute mlirSparseElementsAttrGetIndices(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getIndices()); + return wrap(llvm::cast(unwrap(attr)).getIndices()); } MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getValues()); + return wrap(llvm::cast(unwrap(attr)).getValues()); } //===----------------------------------------------------------------------===// @@ -878,7 +881,7 @@ //===----------------------------------------------------------------------===// bool mlirAttributeIsAStridedLayout(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset, @@ -889,14 +892,14 @@ } int64_t mlirStridedLayoutAttrGetOffset(MlirAttribute attr) { - return unwrap(attr).cast().getOffset(); + return llvm::cast(unwrap(attr)).getOffset(); } intptr_t mlirStridedLayoutAttrGetNumStrides(MlirAttribute attr) { return static_cast( - unwrap(attr).cast().getStrides().size()); + llvm::cast(unwrap(attr)).getStrides().size()); } int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getStrides()[pos]; + return llvm::cast(unwrap(attr)).getStrides()[pos]; } diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -23,7 +23,7 @@ //===----------------------------------------------------------------------===// bool mlirTypeIsAInteger(MlirType type) { - return unwrap(type).isa(); + return llvm::isa(unwrap(type)); } MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth) { @@ -39,26 +39,28 @@ } unsigned mlirIntegerTypeGetWidth(MlirType type) { - return unwrap(type).cast().getWidth(); + return llvm::cast(unwrap(type)).getWidth(); } bool mlirIntegerTypeIsSignless(MlirType type) { - return unwrap(type).cast().isSignless(); + return llvm::cast(unwrap(type)).isSignless(); } bool mlirIntegerTypeIsSigned(MlirType type) { - return unwrap(type).cast().isSigned(); + return llvm::cast(unwrap(type)).isSigned(); } bool mlirIntegerTypeIsUnsigned(MlirType type) { - return unwrap(type).cast().isUnsigned(); + return llvm::cast(unwrap(type)).isUnsigned(); } //===----------------------------------------------------------------------===// // Index type. //===----------------------------------------------------------------------===// -bool mlirTypeIsAIndex(MlirType type) { return unwrap(type).isa(); } +bool mlirTypeIsAIndex(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirIndexTypeGet(MlirContext ctx) { return wrap(IndexType::get(unwrap(ctx))); @@ -136,7 +138,9 @@ // None type. //===----------------------------------------------------------------------===// -bool mlirTypeIsANone(MlirType type) { return unwrap(type).isa(); } +bool mlirTypeIsANone(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirNoneTypeGet(MlirContext ctx) { return wrap(NoneType::get(unwrap(ctx))); @@ -147,7 +151,7 @@ //===----------------------------------------------------------------------===// bool mlirTypeIsAComplex(MlirType type) { - return unwrap(type).isa(); + return llvm::isa(unwrap(type)); } MlirType mlirComplexTypeGet(MlirType elementType) { @@ -155,38 +159,41 @@ } MlirType mlirComplexTypeGetElementType(MlirType type) { - return wrap(unwrap(type).cast().getElementType()); + return wrap(llvm::cast(unwrap(type)).getElementType()); } //===----------------------------------------------------------------------===// // Shaped type. //===----------------------------------------------------------------------===// -bool mlirTypeIsAShaped(MlirType type) { return unwrap(type).isa(); } +bool mlirTypeIsAShaped(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirShapedTypeGetElementType(MlirType type) { - return wrap(unwrap(type).cast().getElementType()); + return wrap(llvm::cast(unwrap(type)).getElementType()); } bool mlirShapedTypeHasRank(MlirType type) { - return unwrap(type).cast().hasRank(); + return llvm::cast(unwrap(type)).hasRank(); } int64_t mlirShapedTypeGetRank(MlirType type) { - return unwrap(type).cast().getRank(); + return llvm::cast(unwrap(type)).getRank(); } bool mlirShapedTypeHasStaticShape(MlirType type) { - return unwrap(type).cast().hasStaticShape(); + return llvm::cast(unwrap(type)).hasStaticShape(); } bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim) { - return unwrap(type).cast().isDynamicDim( - static_cast(dim)); + return llvm::cast(unwrap(type)) + .isDynamicDim(static_cast(dim)); } int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim) { - return unwrap(type).cast().getDimSize(static_cast(dim)); + return llvm::cast(unwrap(type)) + .getDimSize(static_cast(dim)); } int64_t mlirShapedTypeGetDynamicSize() { return ShapedType::kDynamic; } @@ -207,7 +214,9 @@ // Vector type. //===----------------------------------------------------------------------===// -bool mlirTypeIsAVector(MlirType type) { return unwrap(type).isa(); } +bool mlirTypeIsAVector(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape, MlirType elementType) { @@ -226,14 +235,16 @@ // Ranked / Unranked tensor type. //===----------------------------------------------------------------------===// -bool mlirTypeIsATensor(MlirType type) { return unwrap(type).isa(); } +bool mlirTypeIsATensor(MlirType type) { + return llvm::isa(unwrap(type)); +} bool mlirTypeIsARankedTensor(MlirType type) { - return unwrap(type).isa(); + return llvm::isa(unwrap(type)); } bool mlirTypeIsAUnrankedTensor(MlirType type) { - return unwrap(type).isa(); + return llvm::isa(unwrap(type)); } MlirType mlirRankedTensorTypeGet(intptr_t rank, const int64_t *shape, @@ -253,7 +264,7 @@ } MlirAttribute mlirRankedTensorTypeGetEncoding(MlirType type) { - return wrap(unwrap(type).cast().getEncoding()); + return wrap(llvm::cast(unwrap(type)).getEncoding()); } MlirType mlirUnrankedTensorTypeGet(MlirType elementType) { @@ -269,7 +280,9 @@ // Ranked / Unranked MemRef type. //===----------------------------------------------------------------------===// -bool mlirTypeIsAMemRef(MlirType type) { return unwrap(type).isa(); } +bool mlirTypeIsAMemRef(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank, const int64_t *shape, MlirAttribute layout, @@ -278,7 +291,7 @@ llvm::ArrayRef(shape, static_cast(rank)), unwrap(elementType), mlirAttributeIsNull(layout) ? MemRefLayoutAttrInterface() - : unwrap(layout).cast(), + : llvm::cast(unwrap(layout)), unwrap(memorySpace))); } @@ -291,7 +304,7 @@ unwrap(elementType), mlirAttributeIsNull(layout) ? MemRefLayoutAttrInterface() - : unwrap(layout).cast(), + : llvm::cast(unwrap(layout)), unwrap(memorySpace))); } @@ -313,19 +326,19 @@ } MlirAttribute mlirMemRefTypeGetLayout(MlirType type) { - return wrap(unwrap(type).cast().getLayout()); + return wrap(llvm::cast(unwrap(type)).getLayout()); } MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type) { - return wrap(unwrap(type).cast().getLayout().getAffineMap()); + return wrap(llvm::cast(unwrap(type)).getLayout().getAffineMap()); } MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) { - return wrap(unwrap(type).cast().getMemorySpace()); + return wrap(llvm::cast(unwrap(type)).getMemorySpace()); } bool mlirTypeIsAUnrankedMemRef(MlirType type) { - return unwrap(type).isa(); + return llvm::isa(unwrap(type)); } MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, @@ -342,14 +355,16 @@ } MlirAttribute mlirUnrankedMemrefGetMemorySpace(MlirType type) { - return wrap(unwrap(type).cast().getMemorySpace()); + return wrap(llvm::cast(unwrap(type)).getMemorySpace()); } //===----------------------------------------------------------------------===// // Tuple type. //===----------------------------------------------------------------------===// -bool mlirTypeIsATuple(MlirType type) { return unwrap(type).isa(); } +bool mlirTypeIsATuple(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements, MlirType const *elements) { @@ -359,11 +374,12 @@ } intptr_t mlirTupleTypeGetNumTypes(MlirType type) { - return unwrap(type).cast().size(); + return llvm::cast(unwrap(type)).size(); } MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos) { - return wrap(unwrap(type).cast().getType(static_cast(pos))); + return wrap( + llvm::cast(unwrap(type)).getType(static_cast(pos))); } //===----------------------------------------------------------------------===// @@ -371,7 +387,7 @@ //===----------------------------------------------------------------------===// bool mlirTypeIsAFunction(MlirType type) { - return unwrap(type).isa(); + return llvm::isa(unwrap(type)); } MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs, @@ -385,30 +401,32 @@ } intptr_t mlirFunctionTypeGetNumInputs(MlirType type) { - return unwrap(type).cast().getNumInputs(); + return llvm::cast(unwrap(type)).getNumInputs(); } intptr_t mlirFunctionTypeGetNumResults(MlirType type) { - return unwrap(type).cast().getNumResults(); + return llvm::cast(unwrap(type)).getNumResults(); } MlirType mlirFunctionTypeGetInput(MlirType type, intptr_t pos) { assert(pos >= 0 && "pos in array must be positive"); - return wrap( - unwrap(type).cast().getInput(static_cast(pos))); + return wrap(llvm::cast(unwrap(type)) + .getInput(static_cast(pos))); } MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos) { assert(pos >= 0 && "pos in array must be positive"); - return wrap( - unwrap(type).cast().getResult(static_cast(pos))); + return wrap(llvm::cast(unwrap(type)) + .getResult(static_cast(pos))); } //===----------------------------------------------------------------------===// // Opaque type. //===----------------------------------------------------------------------===// -bool mlirTypeIsAOpaque(MlirType type) { return unwrap(type).isa(); } +bool mlirTypeIsAOpaque(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirOpaqueTypeGet(MlirContext ctx, MlirStringRef dialectNamespace, MlirStringRef typeData) { @@ -418,9 +436,10 @@ } MlirStringRef mlirOpaqueTypeGetDialectNamespace(MlirType type) { - return wrap(unwrap(type).cast().getDialectNamespace().strref()); + return wrap( + llvm::cast(unwrap(type)).getDialectNamespace().strref()); } MlirStringRef mlirOpaqueTypeGetData(MlirType type) { - return wrap(unwrap(type).cast().getTypeData()); + return wrap(llvm::cast(unwrap(type)).getTypeData()); } diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -172,7 +172,7 @@ } MlirLocation mlirLocationFromAttribute(MlirAttribute attribute) { - return wrap(Location(unwrap(attribute).cast())); + return wrap(Location(llvm::cast(unwrap(attribute)))); } MlirLocation mlirLocationFileLineColGet(MlirContext context, @@ -727,33 +727,33 @@ } bool mlirValueIsABlockArgument(MlirValue value) { - return unwrap(value).isa(); + return llvm::isa(unwrap(value)); } bool mlirValueIsAOpResult(MlirValue value) { - return unwrap(value).isa(); + return llvm::isa(unwrap(value)); } MlirBlock mlirBlockArgumentGetOwner(MlirValue value) { - return wrap(unwrap(value).cast().getOwner()); + return wrap(llvm::cast(unwrap(value)).getOwner()); } intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) { return static_cast( - unwrap(value).cast().getArgNumber()); + llvm::cast(unwrap(value)).getArgNumber()); } void mlirBlockArgumentSetType(MlirValue value, MlirType type) { - unwrap(value).cast().setType(unwrap(type)); + llvm::cast(unwrap(value)).setType(unwrap(type)); } MlirOperation mlirOpResultGetOwner(MlirValue value) { - return wrap(unwrap(value).cast().getOwner()); + return wrap(llvm::cast(unwrap(value)).getOwner()); } intptr_t mlirOpResultGetResultNumber(MlirValue value) { return static_cast( - unwrap(value).cast().getResultNumber()); + llvm::cast(unwrap(value)).getResultNumber()); } MlirType mlirValueGetType(MlirValue value) { @@ -857,7 +857,7 @@ MlirType mlirAttributeGetType(MlirAttribute attribute) { Attribute attr = unwrap(attribute); - if (auto typedAttr = attr.dyn_cast()) + if (auto typedAttr = llvm::dyn_cast(attr)) return wrap(typedAttr.getType()); return wrap(NoneType::get(attr.getContext())); } diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -48,7 +48,7 @@ Location loc = gpuOp.getLoc(); Value memref = adaptor.getMemref(); Value unconvertedMemref = gpuOp.getMemref(); - MemRefType memrefType = unconvertedMemref.getType().cast(); + MemRefType memrefType = llvm::cast(unconvertedMemref.getType()); if (chipset.majorVersion < 9) return gpuOp.emitOpError("Raw buffer ops require GCN or higher"); @@ -85,13 +85,13 @@ // so bitcast any floats to integers. Type llvmBufferValType = llvmWantedDataType; if (atomicCmpData) { - if (wantedDataType.isa()) + if (llvm::isa(wantedDataType)) return gpuOp.emitOpError("vector compare-and-swap does not exist"); - if (auto floatType = wantedDataType.dyn_cast()) + if (auto floatType = llvm::dyn_cast(wantedDataType)) llvmBufferValType = this->getTypeConverter()->convertType( rewriter.getIntegerType(floatType.getWidth())); } - if (auto dataVector = wantedDataType.dyn_cast()) { + if (auto dataVector = llvm::dyn_cast(wantedDataType)) { uint32_t elemBits = dataVector.getElementTypeBitWidth(); uint32_t totalBits = elemBits * dataVector.getNumElements(); if (totalBits > maxVectorOpWidth) @@ -312,7 +312,7 @@ static Value mfmaConcatIfNeeded(ConversionPatternRewriter &rewriter, Location loc, Value input) { Type inputType = input.getType(); - if (auto vectorType = inputType.dyn_cast()) { + if (auto vectorType = llvm::dyn_cast(inputType)) { if (!vectorType.getElementType().isInteger(8)) return input; int64_t numBytes = vectorType.getNumElements(); @@ -342,10 +342,10 @@ uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(), b = mfma.getBlocks(); Type sourceElem = mfma.getSourceA().getType(); - if (auto sourceType = sourceElem.dyn_cast()) + if (auto sourceType = llvm::dyn_cast(sourceElem)) sourceElem = sourceType.getElementType(); Type destElem = mfma.getDestC().getType(); - if (auto destType = destElem.dyn_cast()) + if (auto destType = llvm::dyn_cast(destElem)) destElem = destType.getElementType(); if (sourceElem.isF32() && destElem.isF32()) { @@ -406,7 +406,7 @@ return ROCDL::mfma_f32_16x16x8bf16::getOperationName(); } - if (sourceElem.isa() && destElem.isInteger(32)) { + if (llvm::isa(sourceElem) && destElem.isInteger(32)) { if (m == 32 && n == 32 && k == 4 && b == 2) return ROCDL::mfma_i32_32x32x4i8::getOperationName(); if (m == 16 && n == 16 && k == 4 && b == 4) @@ -435,7 +435,7 @@ // Known to be correct because there are no scalar f8 instructions and // because a length mismatch will have been caught by the verifier. Type sourceBElem = - mfma.getSourceB().getType().cast().getElementType(); + llvm::cast(mfma.getSourceB().getType()).getElementType(); if (m == 16 && n == 16 && k == 32 && b == 1) { if (sourceBElem.isFloat8E5M2FNUZ()) return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName(); @@ -453,7 +453,7 @@ if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() && chipset.minorVersion >= 0x40) { Type sourceBElem = - mfma.getSourceB().getType().cast().getElementType(); + llvm::cast(mfma.getSourceB().getType()).getElementType(); if (m == 16 && n == 16 && k == 32 && b == 1) { if (sourceBElem.isFloat8E5M2FNUZ()) return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName(); diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -225,8 +225,8 @@ Attribute reduction = std::get<0>(pair); Type resultType = std::get<1>(pair); std::optional reductionOp = - arith::symbolizeAtomicRMWKind( - static_cast(reduction.cast().getInt())); + arith::symbolizeAtomicRMWKind(static_cast( + llvm::cast(reduction).getInt())); assert(reductionOp && "Reduction operation cannot be of None Type"); arith::AtomicRMWKind reductionOpValue = *reductionOp; identityVals.push_back( @@ -246,7 +246,7 @@ // For each of the reduction operations get the respective mlir::Value. std::optional reductionOp = arith::symbolizeAtomicRMWKind( - reductions[i].cast().getInt()); + llvm::cast(reductions[i]).getInt()); assert(reductionOp && "Reduction Operation cannot be of None Type"); arith::AtomicRMWKind reductionOpValue = *reductionOp; rewriter.setInsertionPoint(&parOp.getBody()->back()); diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -210,7 +210,7 @@ // Handle the scalar and 1D vector cases. Type operandType = adaptor.getIn().getType(); - if (!operandType.isa()) { + if (!llvm::isa(operandType)) { Type targetType = this->typeConverter->convertType(resultType); if (targetBits < sourceBits) rewriter.replaceOpWithNewOp(op, targetType, @@ -220,7 +220,7 @@ return success(); } - if (!resultType.isa()) + if (!llvm::isa(resultType)) return rewriter.notifyMatchFailure(op, "expected vector result type"); return LLVM::detail::handleMultidimensionalVectors( @@ -255,7 +255,7 @@ Location loc = op.getLoc(); // Handle the scalar and 1D vector cases. - if (!operandType.isa()) { + if (!llvm::isa(operandType)) { Type newOverflowType = typeConverter->convertType(overflowResultType); Type structType = LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType}); @@ -269,7 +269,7 @@ return success(); } - if (!sumResultType.isa()) + if (!llvm::isa(sumResultType)) return rewriter.notifyMatchFailure(loc, "expected vector result types"); return rewriter.notifyMatchFailure(loc, @@ -295,16 +295,16 @@ // matching extended multiplication intrinsic, perform regular multiplication // on operands zero-extended to i(2*N) bits, and truncate the results back to // iN types. - if (!resultType.isa()) { + if (!llvm::isa(resultType)) { // Shift amount necessary to extract the high bits from widened result. TypedAttr shiftValAttr; - if (auto intTy = resultType.dyn_cast()) { + if (auto intTy = llvm::dyn_cast(resultType)) { unsigned resultBitwidth = intTy.getWidth(); auto attrTy = rewriter.getIntegerType(resultBitwidth * 2); shiftValAttr = rewriter.getIntegerAttr(attrTy, resultBitwidth); } else { - auto vecTy = resultType.cast(); + auto vecTy = llvm::cast(resultType); unsigned resultBitwidth = vecTy.getElementTypeBitWidth(); auto attrTy = VectorType::get( vecTy.getShape(), rewriter.getIntegerType(resultBitwidth * 2)); @@ -330,7 +330,7 @@ return success(); } - if (!resultType.isa()) + if (!llvm::isa(resultType)) return rewriter.notifyMatchFailure(op, "expected vector result type"); return rewriter.notifyMatchFailure(op, @@ -355,7 +355,7 @@ Type resultType = op.getResult().getType(); // Handle the scalar and 1D vector cases. - if (!operandType.isa()) { + if (!llvm::isa(operandType)) { rewriter.replaceOpWithNewOp( op, typeConverter->convertType(resultType), convertCmpPredicate(op.getPredicate()), @@ -363,7 +363,7 @@ return success(); } - if (!resultType.isa()) + if (!llvm::isa(resultType)) return rewriter.notifyMatchFailure(op, "expected vector result type"); return LLVM::detail::handleMultidimensionalVectors( @@ -389,7 +389,7 @@ Type resultType = op.getResult().getType(); // Handle the scalar and 1D vector cases. - if (!operandType.isa()) { + if (!llvm::isa(operandType)) { rewriter.replaceOpWithNewOp( op, typeConverter->convertType(resultType), convertCmpPredicate(op.getPredicate()), @@ -397,7 +397,7 @@ return success(); } - if (!resultType.isa()) + if (!llvm::isa(resultType)) return rewriter.notifyMatchFailure(op, "expected vector result type"); return LLVM::detail::handleMultidimensionalVectors( diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -366,12 +366,12 @@ static std::optional convertAsyncTypes(Type type, bool useOpaquePointers) { - if (type.isa()) + if (llvm::isa(type)) return AsyncAPI::opaquePointerType(type.getContext(), useOpaquePointers); - if (type.isa()) + if (llvm::isa(type)) return AsyncAPI::tokenType(type.getContext()); - if (type.isa()) + if (llvm::isa(type)) return AsyncAPI::opaquePointerType(type.getContext(), useOpaquePointers); return std::nullopt; @@ -656,14 +656,14 @@ Type resultType = op->getResultTypes()[0]; // Tokens creation maps to a simple function call. - if (resultType.isa()) { + if (llvm::isa(resultType)) { rewriter.replaceOpWithNewOp( op, kCreateToken, converter->convertType(resultType)); return success(); } // To create a value we need to compute the storage requirement. - if (auto value = resultType.dyn_cast()) { + if (auto value = llvm::dyn_cast(resultType)) { // Returns the size requirements for the async value storage. auto sizeOf = [&](ValueType valueType) -> Value { auto loc = op->getLoc(); @@ -994,7 +994,7 @@ matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Currently we can only add tokens to the group. - if (!op.getOperand().getType().isa()) + if (!llvm::isa(op.getOperand().getType())) return rewriter.notifyMatchFailure(op, "only token type is supported"); // Replace with a runtime API function call. diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp --- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp +++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp @@ -41,11 +41,11 @@ ConversionPatternRewriter &rewriter) const override { // Check for unranked memref types which are currently not supported. Type type = op.getType(); - if (type.isa()) { + if (llvm::isa(type)) { return rewriter.notifyMatchFailure( op, "UnrankedMemRefType is not supported."); } - MemRefType memrefType = type.cast(); + MemRefType memrefType = llvm::cast(type); MemRefLayoutAttrInterface layout; auto allocType = MemRefType::get(memrefType.getShape(), memrefType.getElementType(), diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -144,13 +144,13 @@ size_t argOffset = resultStructType ? 1 : 0; for (auto [index, argType] : llvm::enumerate(type.getInputs())) { Value arg = wrapperFuncOp.getArgument(index + argOffset); - if (auto memrefType = argType.dyn_cast()) { + if (auto memrefType = llvm::dyn_cast(argType)) { Value loaded = rewriter.create( loc, typeConverter.convertType(memrefType), arg); MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args); continue; } - if (argType.isa()) { + if (llvm::isa(argType)) { Value loaded = rewriter.create( loc, typeConverter.convertType(argType), arg); UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args); @@ -219,7 +219,7 @@ if (resultStructType) { // Allocate the struct on the stack and pass the pointer. Type resultType = - wrapperType.cast().getParamType(0); + llvm::cast(wrapperType).getParamType(0); Value one = builder.create( loc, typeConverter.convertType(builder.getIndexType()), builder.getIntegerAttr(builder.getIndexType(), 1)); @@ -233,8 +233,8 @@ for (Type input : type.getInputs()) { Value arg; int numToDrop = 1; - auto memRefType = input.dyn_cast(); - auto unrankedMemRefType = input.dyn_cast(); + auto memRefType = llvm::dyn_cast(input); + auto unrankedMemRefType = llvm::dyn_cast(input); if (memRefType || unrankedMemRefType) { numToDrop = memRefType ? MemRefDescriptor::getNumUnpackedValues(memRefType) @@ -301,9 +301,9 @@ // Unranked memrefs are not supported in the bare pointer calling // convention. We should have bailed out before in the presence of // unranked memrefs. - assert(!argTy.isa() && + assert(!llvm::isa(argTy) && "Unranked memref is not supported"); - auto memrefTy = argTy.dyn_cast(); + auto memrefTy = llvm::dyn_cast(argTy); if (!memrefTy) continue; @@ -360,18 +360,18 @@ } if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) { SmallVector newArgAttrs( - llvmType.cast().getNumParams()); + llvm::cast(llvmType).getNumParams()); for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) { // Some LLVM IR attribute have a type attached to them. During FuncOp -> // LLVMFuncOp conversion these types may have changed. Account for that // change by converting attributes' types as well. SmallVector convertedAttrs; - auto attrsDict = argAttrDicts[i].cast(); + auto attrsDict = llvm::cast(argAttrDicts[i]); convertedAttrs.reserve(attrsDict.size()); for (const NamedAttribute &attr : attrsDict) { const auto convert = [&](const NamedAttribute &attr) { return TypeAttr::get(getTypeConverter()->convertType( - attr.getValue().cast().getValue())); + llvm::cast(attr.getValue()).getValue())); }; if (attr.getName().getValue() == LLVM::LLVMDialect::getByValAttrName()) { @@ -417,8 +417,8 @@ // functions have linkage. LLVM::Linkage linkage = LLVM::Linkage::External; if (funcOp->hasAttr(linkageAttrName)) { - auto attr = - funcOp->getAttr(linkageAttrName).dyn_cast(); + auto attr = llvm::dyn_cast( + funcOp->getAttr(linkageAttrName)); if (!attr) { funcOp->emitError() << "Contains " << linkageAttrName << " attribute not of type LLVM::LinkageAttr"; @@ -545,7 +545,7 @@ if (useBarePtrCallConv) { for (auto it : callOp->getOperands()) { Type operandType = it.getType(); - if (operandType.isa()) { + if (llvm::isa(operandType)) { // Unranked memref is not supported in the bare pointer calling // convention. return failure(); @@ -669,11 +669,12 @@ for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) { Type oldTy = std::get<0>(it).getType(); Value newOperand = std::get<1>(it); - if (oldTy.isa() && getTypeConverter()->canConvertToBarePtr( - oldTy.cast())) { + if (llvm::isa(oldTy) && + getTypeConverter()->canConvertToBarePtr( + llvm::cast(oldTy))) { MemRefDescriptor memrefDesc(newOperand); newOperand = memrefDesc.allocatedPtr(rewriter, loc); - } else if (oldTy.isa()) { + } else if (llvm::isa(oldTy)) { // Unranked memref is not supported in the bare pointer calling // convention. return failure(); diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -26,22 +26,20 @@ for (const auto &en : llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) { BlockArgument attribution = en.value(); - auto type = attribution.getType().dyn_cast(); + auto type = llvm::dyn_cast(attribution.getType()); assert(type && type.hasStaticShape() && "unexpected type in attribution"); uint64_t numElements = type.getNumElements(); auto elementType = - typeConverter->convertType(type.getElementType()).template cast(); + llvm::cast(typeConverter->convertType(type.getElementType())); auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements); std::string name = std::string( llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index())); uint64_t alignment = 0; - if (auto alignAttr = - gpuFuncOp - .getWorkgroupAttributionAttr( - en.index(), LLVM::LLVMDialect::getAlignAttrName()) - .dyn_cast_or_null()) + if (auto alignAttr = llvm::dyn_cast_or_null( + gpuFuncOp.getWorkgroupAttributionAttr( + en.index(), LLVM::LLVMDialect::getAlignAttrName()))) alignment = alignAttr.getInt(); auto globalOp = rewriter.create( gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false, @@ -100,7 +98,7 @@ global.getAddrSpace()), global.getSymNameAttr()); auto elementType = - global.getType().cast().getElementType(); + llvm::cast(global.getType()).getElementType(); Value memory = rewriter.create( loc, getTypeConverter()->getPointerType(elementType, @@ -112,7 +110,7 @@ // otherwise necessary given that memref sizes are fixed, but we can try // and canonicalize that away later. Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()]; - auto type = attribution.getType().cast(); + auto type = llvm::cast(attribution.getType()); auto descr = MemRefDescriptor::fromStaticShape( rewriter, loc, *getTypeConverter(), type, memory); signatureConversion.remapInput(numProperArguments + en.index(), descr); @@ -123,7 +121,7 @@ auto int64Ty = IntegerType::get(rewriter.getContext(), 64); for (const auto &en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) { Value attribution = en.value(); - auto type = attribution.getType().cast(); + auto type = llvm::cast(attribution.getType()); assert(type && type.hasStaticShape() && "unexpected type in attribution"); // Explicitly drop memory space when lowering private memory @@ -135,11 +133,9 @@ Value numElements = rewriter.create( gpuFuncOp.getLoc(), int64Ty, type.getNumElements()); uint64_t alignment = 0; - if (auto alignAttr = - gpuFuncOp - .getPrivateAttributionAttr( - en.index(), LLVM::LLVMDialect::getAlignAttrName()) - .dyn_cast_or_null()) + if (auto alignAttr = llvm::dyn_cast_or_null( + gpuFuncOp.getPrivateAttributionAttr( + en.index(), LLVM::LLVMDialect::getAlignAttrName()))) alignment = alignAttr.getInt(); Value allocated = rewriter.create( gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment); @@ -164,7 +160,7 @@ OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&llvmFuncOp.getBody().front()); for (const auto &en : llvm::enumerate(gpuFuncOp.getArgumentTypes())) { - auto memrefTy = en.value().dyn_cast(); + auto memrefTy = llvm::dyn_cast(en.value()); if (!memrefTy) continue; assert(memrefTy.hasStaticShape() && @@ -302,7 +298,7 @@ rewriter.create(loc, llvmI32, numArgsThisCall)); for (size_t i = group; i < bound; ++i) { Value arg = adaptor.getArgs()[i]; - if (auto floatType = arg.getType().dyn_cast()) { + if (auto floatType = llvm::dyn_cast(arg.getType())) { if (!floatType.isF64()) arg = rewriter.create( loc, typeConverter->convertType(rewriter.getF64Type()), arg); @@ -428,7 +424,7 @@ Type type = arg.getType(); Value promotedArg = arg; assert(type.isIntOrFloat()); - if (type.isa()) { + if (llvm::isa(type)) { type = rewriter.getF64Type(); promotedArg = rewriter.create(loc, type, arg); } @@ -462,14 +458,15 @@ LLVMTypeConverter &converter) { TypeRange operandTypes(operands); if (llvm::none_of(operandTypes, - [](Type type) { return type.isa(); })) { + [](Type type) { return llvm::isa(type); })) { return rewriter.notifyMatchFailure(op, "expected vector operand"); } if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0) return rewriter.notifyMatchFailure(op, "expected no region/successor"); if (op->getNumResults() != 1) return rewriter.notifyMatchFailure(op, "expected single result"); - VectorType vectorType = op->getResult(0).getType().dyn_cast(); + VectorType vectorType = + llvm::dyn_cast(op->getResult(0).getType()); if (!vectorType) return rewriter.notifyMatchFailure(op, "expected vector result"); @@ -482,7 +479,7 @@ for (int64_t i = 0; i < vectorType.getNumElements(); ++i) { Value index = rewriter.create(loc, indexType, i); auto extractElement = [&](Value operand) -> Value { - if (!operand.getType().isa()) + if (!llvm::isa(operand.getType())) return operand; return rewriter.create(loc, operand, index); }; diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -454,7 +454,8 @@ Location loc = op->getLoc(); auto memRefType = hostRegisterOp.getValue().getType(); - auto elementType = memRefType.cast().getElementType(); + auto elementType = + llvm::cast(memRefType).getElementType(); auto elementSize = getSizeInBytes(loc, elementType, rewriter); auto arguments = getTypeConverter()->promoteOperands( @@ -476,7 +477,8 @@ Location loc = op->getLoc(); auto memRefType = hostUnregisterOp.getValue().getType(); - auto elementType = memRefType.cast().getElementType(); + auto elementType = + llvm::cast(memRefType).getElementType(); auto elementSize = getSizeInBytes(loc, elementType, rewriter); auto arguments = getTypeConverter()->promoteOperands( @@ -555,7 +557,7 @@ } static bool isGpuAsyncTokenType(Value value) { - return value.getType().isa(); + return llvm::isa(value.getType()); } // Converts !gpu.async.token operands of `async.yield` to runtime calls. The @@ -591,7 +593,7 @@ // Returns whether `value` is the result of an LLVM::CallOp to `functionName`. static bool isDefinedByCallTo(Value value, StringRef functionName) { - assert(value.getType().isa()); + assert(llvm::isa(value.getType())); if (auto defOp = value.getDefiningOp()) return defOp.getCallee()->equals(functionName); return false; @@ -862,7 +864,7 @@ LLVM::LLVMPointerType destinationType, Value sourcePtr, LLVMTypeConverter &typeConverter) { - auto sourceTy = sourcePtr.getType().cast(); + auto sourceTy = llvm::cast(sourcePtr.getType()); if (destinationType.getAddressSpace() != sourceTy.getAddressSpace()) sourcePtr = rewriter.create( loc, @@ -879,7 +881,7 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite( gpu::MemcpyOp memcpyOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - auto memRefType = memcpyOp.getSrc().getType().cast(); + auto memRefType = llvm::cast(memcpyOp.getSrc().getType()); if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) || !isConvertibleAndHasIdentityMaps(memRefType) || @@ -919,7 +921,7 @@ LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite( gpu::MemsetOp memsetOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - auto memRefType = memsetOp.getDst().getType().cast(); + auto memRefType = llvm::cast(memsetOp.getDst().getType()); if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) || !isConvertibleAndHasIdentityMaps(memRefType) || diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h --- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h @@ -55,7 +55,7 @@ Type resultType = castedOperands.front().getType(); Type funcType = getFunctionType(resultType, castedOperands); StringRef funcName = getFunctionName( - funcType.cast().getReturnType()); + llvm::cast(funcType).getReturnType()); if (funcName.empty()) return failure(); @@ -78,7 +78,7 @@ private: Value maybeCast(Value operand, PatternRewriter &rewriter) const { Type type = operand.getType(); - if (!type.isa()) + if (!llvm::isa(type)) return operand; return rewriter.create( @@ -91,9 +91,9 @@ } StringRef getFunctionName(Type type) const { - if (type.isa()) + if (llvm::isa(type)) return f32Func; - if (type.isa()) + if (llvm::isa(type)) return f64Func; return ""; } diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp --- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp @@ -90,8 +90,8 @@ NVVM::MMALayout layout = subgroupMmaLoadMatrixOp.getTranspose() ? NVVM::MMALayout::col : NVVM::MMALayout::row; - gpu::MMAMatrixType retType = - subgroupMmaLoadMatrixOp.getRes().getType().cast(); + gpu::MMAMatrixType retType = llvm::cast( + subgroupMmaLoadMatrixOp.getRes().getType()); ArrayRef retTypeShape = retType.getShape(); int64_t m = 0; int64_t n = 0; @@ -123,7 +123,8 @@ // Create nvvm.mma_load op according to the operand types. Value dataPtr = getStridedElementPtr( loc, - subgroupMmaLoadMatrixOp.getSrcMemref().getType().cast(), + llvm::cast( + subgroupMmaLoadMatrixOp.getSrcMemref().getType()), adaptor.getSrcMemref(), adaptor.getIndices(), rewriter); Value leadingDim = rewriter.create( @@ -157,8 +158,8 @@ SmallVector storeOpOperands; // Get the shape of the MMAMatrix type being stored. The shape will // choose which intrinsic this op will be lowered to. - gpu::MMAMatrixType srcType = - subgroupMmaStoreMatrixOp.getSrc().getType().cast(); + gpu::MMAMatrixType srcType = llvm::cast( + subgroupMmaStoreMatrixOp.getSrc().getType()); ArrayRef srcTypeShape = srcType.getShape(); NVVM::MMALayout layout = subgroupMmaStoreMatrixOp.getTranspose() ? NVVM::MMALayout::col @@ -170,7 +171,8 @@ if (NVVM::WMMAStoreOp::getIntrinsicID(m, n, k, layout, eltype) == 0) return rewriter.notifyMatchFailure(op, kInvalidCaseStr); - auto matrixType = adaptor.getSrc().getType().cast(); + auto matrixType = + llvm::cast(adaptor.getSrc().getType()); for (unsigned i = 0, e = matrixType.getBody().size(); i < e; ++i) { Value toUse = rewriter.create(loc, adaptor.getSrc(), i); @@ -179,7 +181,8 @@ Value dataPtr = getStridedElementPtr( loc, - subgroupMmaStoreMatrixOp.getDstMemref().getType().cast(), + llvm::cast( + subgroupMmaStoreMatrixOp.getDstMemref().getType()), adaptor.getDstMemref(), adaptor.getIndices(), rewriter); Value leadingDim = rewriter.create( loc, rewriter.getI32Type(), @@ -214,7 +217,7 @@ SmallVector unpackedOps; auto unpackOp = [&](Value operand) { - auto structType = operand.getType().cast(); + auto structType = llvm::cast(operand.getType()); for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) { Value toUse = rewriter.create(loc, operand, i); unpackedOps.push_back(toUse); @@ -224,10 +227,10 @@ // Get the shapes of the MMAMatrix type being used. The shapes will // choose which intrinsic this op will be lowered to. gpu::MMAMatrixType aType = - subgroupMmaComputeOp.getOpA().getType().cast(); + llvm::cast(subgroupMmaComputeOp.getOpA().getType()); ArrayRef aTypeShape = aType.getShape(); gpu::MMAMatrixType cType = - subgroupMmaComputeOp.getOpC().getType().cast(); + llvm::cast(subgroupMmaComputeOp.getOpC().getType()); ArrayRef cTypeShape = cType.getShape(); int64_t m = cTypeShape[0]; int64_t n = cTypeShape[1]; @@ -244,8 +247,8 @@ destType) == 0) return rewriter.notifyMatchFailure(op, kInvalidCaseStr); - NVVM::MMATypes bElementType = getElementType( - subgroupMmaComputeOp.getOpB().getType().cast()); + NVVM::MMATypes bElementType = getElementType(llvm::cast( + subgroupMmaComputeOp.getOpB().getType())); if (bElementType != sourceType) return rewriter.notifyMatchFailure( op, "WMMA compute op input matrix element types must match."); @@ -277,9 +280,9 @@ Location loc = subgroupMmaConstantOp.getLoc(); Value cst = adaptor.getOperands()[0]; LLVM::LLVMStructType type = convertMMAToLLVMType( - subgroupMmaConstantOp.getType().cast()); + llvm::cast(subgroupMmaConstantOp.getType())); // If the element type is a vector create a vector from the operand. - if (auto vecType = type.getBody()[0].dyn_cast()) { + if (auto vecType = llvm::dyn_cast(type.getBody()[0])) { Value vecCst = rewriter.create(loc, vecType); for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) { Value idx = rewriter.create( @@ -301,9 +304,9 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs, Value rhs, bool isMin) { - auto floatType = getElementTypeOrSelf(lhs.getType()).cast(); + auto floatType = llvm::cast(getElementTypeOrSelf(lhs.getType())); Type i1Type = builder.getI1Type(); - if (auto vecType = lhs.getType().dyn_cast()) + if (auto vecType = llvm::dyn_cast(lhs.getType())) i1Type = VectorType::get(vecType.getShape(), i1Type); Value cmp = builder.create( loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt, @@ -355,7 +358,7 @@ Location loc = subgroupMmaElementwiseOp.getLoc(); size_t numOperands = adaptor.getOperands().size(); LLVM::LLVMStructType destType = convertMMAToLLVMType( - subgroupMmaElementwiseOp.getType().cast()); + llvm::cast(subgroupMmaElementwiseOp.getType())); Value matrixStruct = rewriter.create(loc, destType); for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) { SmallVector extractedOperands; diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -54,7 +54,7 @@ static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) { bool canBeBare = true; for (Type type : func.getArgumentTypes()) - if (auto memrefTy = type.dyn_cast()) + if (auto memrefTy = llvm::dyn_cast(type)) canBeBare &= LLVMTypeConverter::canConvertToBarePtr(memrefTy); return canBeBare; } @@ -166,9 +166,8 @@ // Manually rewrite known block size attributes so the LLVMIR translation // infrastructure can pick them up. m.walk([ctx](LLVM::LLVMFuncOp op) { - if (auto blockSizes = - op->removeAttr(gpu::GPUFuncOp::getKnownBlockSizeAttrName()) - .dyn_cast_or_null()) { + if (auto blockSizes = llvm::dyn_cast_or_null( + op->removeAttr(gpu::GPUFuncOp::getKnownBlockSizeAttrName()))) { op->setAttr(ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName(), blockSizes); // Also set up the rocdl.flat_work_group_size attribute to prevent diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp --- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp @@ -142,11 +142,11 @@ /// Returns a string representation from the given `type`. StringRef stringifyType(Type type) { - if (type.isa()) + if (llvm::isa(type)) return "Float"; - if (type.isa()) + if (llvm::isa(type)) return "Half"; - if (auto intType = type.dyn_cast()) { + if (auto intType = llvm::dyn_cast(type)) { if (intType.getWidth() == 32) return "Int32"; if (intType.getWidth() == 16) @@ -282,7 +282,7 @@ llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str(); // Special case for fp16 type. Since it is not a supported type in C we use // int16_t and bitcast the descriptor. - if (!useOpaquePointers && type.isa()) { + if (!useOpaquePointers && llvm::isa(type)) { auto memRefTy = getMemRefType(rank, IntegerType::get(&getContext(), 16)); ptrToMemRefDescriptor = builder.create( loc, LLVM::LLVMPointerType::get(memRefTy), ptrToMemRefDescriptor); @@ -328,8 +328,7 @@ rank = 0; return success(); } - rank = llvmDescriptorTy.getBody()[3] - .cast() + rank = llvm::cast(llvmDescriptorTy.getBody()[3]) .getNumElements(); return success(); } @@ -375,7 +374,7 @@ for (auto type : types) { std::string fnName = "bindMemRef" + std::to_string(i) + "D" + std::string(stringifyType(type)); - if (type.isa()) + if (llvm::isa(type)) type = IntegerType::get(&getContext(), 16); if (!module.lookupSymbol(fnName)) { auto fnType = LLVM::LLVMFunctionType::get( diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp --- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp +++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp @@ -24,8 +24,7 @@ MemRefDescriptor::MemRefDescriptor(Value descriptor) : StructBuilder(descriptor) { assert(value != nullptr && "value cannot be null"); - indexType = value.getType() - .cast() + indexType = llvm::cast(value.getType()) .getBody()[kOffsetPosInMemRefDescriptor]; } @@ -193,10 +192,10 @@ } LLVM::LLVMPointerType MemRefDescriptor::getElementPtrType() { - return value.getType() - .cast() - .getBody()[kAlignedPtrPosInMemRefDescriptor] - .cast(); + return llvm::cast( + value.getType() + .cast() + .getBody()[kAlignedPtrPosInMemRefDescriptor]); } Value MemRefDescriptor::bufferPtr(OpBuilder &builder, Location loc, diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -235,7 +235,7 @@ SmallVector unrankedMemrefs; SmallVector unrankedAddressSpaces; for (unsigned i = 0, e = operands.size(); i < e; ++i) { - if (auto memRefType = origTypes[i].dyn_cast()) { + if (auto memRefType = llvm::dyn_cast(origTypes[i])) { unrankedMemrefs.emplace_back(operands[i]); FailureOr addressSpace = getTypeConverter()->getMemRefAddressSpace(memRefType); @@ -276,7 +276,7 @@ unsigned unrankedMemrefPos = 0; for (unsigned i = 0, e = operands.size(); i < e; ++i) { Type type = origTypes[i]; - if (!type.isa()) + if (!llvm::isa(type)) continue; Value allocationSize = sizes[unrankedMemrefPos++]; UnrankedMemRefDescriptor desc(operands[i]); diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -260,7 +260,7 @@ if (!resultType) return {}; - auto structType = resultType.dyn_cast(); + auto structType = llvm::dyn_cast(resultType); if (structType) { // Struct types cannot be safely returned via C interface. Make this a // pointer argument, instead. @@ -272,7 +272,7 @@ auto converted = convertType(t); if (!converted || !LLVM::isCompatibleType(converted)) return {}; - if (t.isa()) + if (llvm::isa(t)) converted = getPointerType(converted); inputs.push_back(converted); } @@ -412,13 +412,13 @@ // Check if a memref type can be converted to a bare pointer. bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type) { - if (type.isa()) + if (llvm::isa(type)) // Unranked memref is not supported in the bare pointer calling convention. return false; // Check that the memref has static shape, strides and offset. Otherwise, it // cannot be lowered to a bare pointer. - auto memrefTy = type.cast(); + auto memrefTy = llvm::cast(type); if (!memrefTy.hasStaticShape()) return false; @@ -476,7 +476,7 @@ Type LLVMTypeConverter::convertCallingConventionType(Type type, bool useBarePtrCallConv) { if (useBarePtrCallConv) - if (auto memrefTy = type.dyn_cast()) + if (auto memrefTy = llvm::dyn_cast(type)) return convertMemRefToBarePtr(memrefTy); return convertType(type); @@ -491,7 +491,7 @@ assert(stdTypes.size() == values.size() && "The number of types and values doesn't match"); for (unsigned i = 0, end = values.size(); i < end; ++i) - if (auto memrefTy = stdTypes[i].dyn_cast()) + if (auto memrefTy = llvm::dyn_cast(stdTypes[i])) values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this, memrefTy, values[i]); } @@ -569,19 +569,19 @@ if (useBarePtrCallConv) { // For the bare-ptr calling convention, we only have to extract the // aligned pointer of a memref. - if (auto memrefType = operand.getType().dyn_cast()) { + if (auto memrefType = llvm::dyn_cast(operand.getType())) { MemRefDescriptor desc(llvmOperand); llvmOperand = desc.alignedPtr(builder, loc); - } else if (operand.getType().isa()) { + } else if (llvm::isa(operand.getType())) { llvm_unreachable("Unranked memrefs are not supported"); } } else { - if (operand.getType().isa()) { + if (llvm::isa(operand.getType())) { UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand, promotedOperands); continue; } - if (auto memrefType = operand.getType().dyn_cast()) { + if (auto memrefType = llvm::dyn_cast(operand.getType())) { MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType, promotedOperands); continue; @@ -600,7 +600,7 @@ LogicalResult mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter, Type type, SmallVectorImpl &result) { - if (auto memref = type.dyn_cast()) { + if (auto memref = llvm::dyn_cast(type)) { // In signatures, Memref descriptors are expanded into lists of // non-aggregate values. auto converted = @@ -610,7 +610,7 @@ result.append(converted.begin(), converted.end()); return success(); } - if (type.isa()) { + if (llvm::isa(type)) { auto converted = converter.getUnrankedMemRefDescriptorFields(); if (converted.empty()) return failure(); diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp --- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp @@ -27,10 +27,10 @@ } info.arraySizes.reserve(vectorType.getRank() - 1); auto llvmTy = info.llvmNDVectorTy; - while (llvmTy.isa()) { + while (llvm::isa(llvmTy)) { info.arraySizes.push_back( - llvmTy.cast().getNumElements()); - llvmTy = llvmTy.cast().getElementType(); + llvm::cast(llvmTy).getNumElements()); + llvmTy = llvm::cast(llvmTy).getElementType(); } if (!LLVM::isCompatibleVectorType(llvmTy)) return info; @@ -81,7 +81,7 @@ Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter, std::function createOperand, ConversionPatternRewriter &rewriter) { - auto resultNDVectorType = op->getResult(0).getType().cast(); + auto resultNDVectorType = llvm::cast(op->getResult(0).getType()); auto resultTypeInfo = extractNDVectorTypeInfo(resultNDVectorType, typeConverter); auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy; @@ -114,7 +114,7 @@ return failure(); auto llvmNDVectorTy = operands[0].getType(); - if (!llvmNDVectorTy.isa()) + if (!llvm::isa(llvmNDVectorTy)) return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter, rewriter); diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp --- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp +++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp @@ -42,7 +42,7 @@ // The underlying descriptor type (e.g. LLVM) does not have layout // information. Canonicalizing the type at the level of std when going into // a library call avoids needing to introduce DialectCastOp. - if (auto memrefType = type.dyn_cast()) + if (auto memrefType = llvm::dyn_cast(type)) result.push_back(makeStridedLayoutDynamic(memrefType)); else result.push_back(type); @@ -96,7 +96,7 @@ SmallVector res; res.reserve(operands.size()); for (auto op : operands) { - auto memrefType = op.getType().dyn_cast(); + auto memrefType = llvm::dyn_cast(op.getType()); if (!memrefType) { res.push_back(op); continue; diff --git a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp --- a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp +++ b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp @@ -106,7 +106,7 @@ VecOpToScalarOp::matchAndRewrite(Op op, PatternRewriter &rewriter) const { Type opType = op.getType(); Location loc = op.getLoc(); - auto vecType = opType.template dyn_cast(); + auto vecType = llvm::dyn_cast(opType); if (!vecType) return rewriter.notifyMatchFailure(op, "not a vector operation"); @@ -117,7 +117,7 @@ Type resultElementType = vecType.getElementType(); Attribute initValueAttr; - if (resultElementType.isa()) + if (llvm::isa(resultElementType)) initValueAttr = FloatAttr::get(resultElementType, 0.0); else initValueAttr = IntegerAttr::get(resultElementType, 0); @@ -183,7 +183,7 @@ /// } /// } static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) { - assert(elementType.isa() && + assert(llvm::isa(elementType) && "non-integer element type for IPowIOp"); ImplicitLocOpBuilder builder = @@ -361,7 +361,7 @@ LogicalResult IPowIOpLowering::matchAndRewrite(math::IPowIOp op, PatternRewriter &rewriter) const { - auto baseType = op.getOperands()[0].getType().dyn_cast(); + auto baseType = llvm::dyn_cast(op.getOperands()[0].getType()); if (!baseType) return rewriter.notifyMatchFailure(op, "non-integer base operand"); @@ -411,8 +411,8 @@ /// } static func::FuncOp createElementFPowIFunc(ModuleOp *module, FunctionType funcType) { - auto baseType = funcType.getInput(0).cast(); - auto powType = funcType.getInput(1).cast(); + auto baseType = llvm::cast(funcType.getInput(0)); + auto powType = llvm::cast(funcType.getInput(1)); ImplicitLocOpBuilder builder = ImplicitLocOpBuilder::atBlockEnd(module->getLoc(), module->getBody()); @@ -586,7 +586,7 @@ LogicalResult FPowIOpLowering::matchAndRewrite(math::FPowIOp op, PatternRewriter &rewriter) const { - if (op.getType().template dyn_cast()) + if (llvm::dyn_cast(op.getType())) return rewriter.notifyMatchFailure(op, "non-scalar operation"); FunctionType funcType = getElementalFuncTypeForOp(op); @@ -649,7 +649,7 @@ /// return %out: i32 /// } static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) { - if (!elementType.isa()) { + if (!llvm::isa(elementType)) { LLVM_DEBUG({ DBGS() << "non-integer element type for CtlzFunc; type was: "; elementType.print(llvm::dbgs()); @@ -751,7 +751,7 @@ /// operation. LogicalResult CtlzOpLowering::matchAndRewrite(math::CountLeadingZerosOp op, PatternRewriter &rewriter) const { - if (op.getType().template dyn_cast()) + if (llvm::dyn_cast(op.getType())) return rewriter.notifyMatchFailure(op, "non-scalar operation"); Type type = getElementTypeOrSelf(op.getResult().getType()); @@ -794,7 +794,7 @@ bool ConvertMathToFuncsPass::isFPowIConvertible(math::FPowIOp op) { auto expTy = - getElementTypeOrSelf(op.getRhs().getType()).dyn_cast(); + llvm::dyn_cast(getElementTypeOrSelf(op.getRhs().getType())); return (expTy && expTy.getWidth() >= minWidthOfFPowIExponent); } diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp --- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp +++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp @@ -79,14 +79,14 @@ auto resultType = op.getResult().getType(); auto boolZero = rewriter.getBoolAttr(false); - if (!operandType.template isa()) { + if (!llvm::isa(operandType)) { LLVM::ConstantOp zero = rewriter.create(loc, boolZero); rewriter.replaceOpWithNewOp(op, resultType, adaptor.getOperand(), zero); return success(); } - auto vectorType = resultType.template dyn_cast(); + auto vectorType = llvm::dyn_cast(resultType); if (!vectorType) return failure(); @@ -122,17 +122,18 @@ auto loc = op.getLoc(); auto resultType = op.getResult().getType(); - auto floatType = getElementTypeOrSelf(resultType).cast(); + auto floatType = llvm::cast(getElementTypeOrSelf(resultType)); auto floatOne = rewriter.getFloatAttr(floatType, 1.0); ConvertFastMath expAttrs(op); ConvertFastMath subAttrs(op); - if (!operandType.isa()) { + if (!llvm::isa(operandType)) { LLVM::ConstantOp one; if (LLVM::isCompatibleVectorType(operandType)) { one = rewriter.create( loc, operandType, - SplatElementsAttr::get(resultType.cast(), floatOne)); + SplatElementsAttr::get(llvm::cast(resultType), + floatOne)); } else { one = rewriter.create(loc, operandType, floatOne); } @@ -143,7 +144,7 @@ return success(); } - auto vectorType = resultType.dyn_cast(); + auto vectorType = llvm::dyn_cast(resultType); if (!vectorType) return rewriter.notifyMatchFailure(op, "expected vector result type"); @@ -180,17 +181,17 @@ auto loc = op.getLoc(); auto resultType = op.getResult().getType(); - auto floatType = getElementTypeOrSelf(resultType).cast(); + auto floatType = llvm::cast(getElementTypeOrSelf(resultType)); auto floatOne = rewriter.getFloatAttr(floatType, 1.0); ConvertFastMath addAttrs(op); ConvertFastMath logAttrs(op); - if (!operandType.isa()) { + if (!llvm::isa(operandType)) { LLVM::ConstantOp one = LLVM::isCompatibleVectorType(operandType) ? rewriter.create( loc, operandType, - SplatElementsAttr::get(resultType.cast(), + SplatElementsAttr::get(llvm::cast(resultType), floatOne)) : rewriter.create(loc, operandType, floatOne); @@ -202,7 +203,7 @@ return success(); } - auto vectorType = resultType.dyn_cast(); + auto vectorType = llvm::dyn_cast(resultType); if (!vectorType) return rewriter.notifyMatchFailure(op, "expected vector result type"); @@ -240,17 +241,18 @@ auto loc = op.getLoc(); auto resultType = op.getResult().getType(); - auto floatType = getElementTypeOrSelf(resultType).cast(); + auto floatType = llvm::cast(getElementTypeOrSelf(resultType)); auto floatOne = rewriter.getFloatAttr(floatType, 1.0); ConvertFastMath sqrtAttrs(op); ConvertFastMath divAttrs(op); - if (!operandType.isa()) { + if (!llvm::isa(operandType)) { LLVM::ConstantOp one; if (LLVM::isCompatibleVectorType(operandType)) { one = rewriter.create( loc, operandType, - SplatElementsAttr::get(resultType.cast(), floatOne)); + SplatElementsAttr::get(llvm::cast(resultType), + floatOne)); } else { one = rewriter.create(loc, operandType, floatOne); } @@ -261,7 +263,7 @@ return success(); } - auto vectorType = resultType.dyn_cast(); + auto vectorType = llvm::dyn_cast(resultType); if (!vectorType) return failure(); diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp --- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp +++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp @@ -75,7 +75,7 @@ VecOpToScalarOp::matchAndRewrite(Op op, PatternRewriter &rewriter) const { auto opType = op.getType(); auto loc = op.getLoc(); - auto vecType = opType.template dyn_cast(); + auto vecType = llvm::dyn_cast(opType); if (!vecType) return failure(); @@ -107,7 +107,7 @@ LogicalResult PromoteOpToF32::matchAndRewrite(Op op, PatternRewriter &rewriter) const { auto opType = op.getType(); - if (!opType.template isa()) + if (!llvm::isa(opType)) return failure(); auto loc = op.getLoc(); @@ -127,7 +127,7 @@ PatternRewriter &rewriter) const { auto module = SymbolTable::getNearestSymbolTable(op); auto type = op.getType(); - if (!type.template isa()) + if (!llvm::isa(type)) return failure(); auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc; diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp @@ -58,7 +58,8 @@ Location loc, Value allocatedPtr, MemRefType memRefType, Type elementPtrType, LLVMTypeConverter &typeConverter) { - auto allocatedPtrTy = allocatedPtr.getType().cast(); + auto allocatedPtrTy = + llvm::cast(allocatedPtr.getType()); unsigned memrefAddrSpace = *typeConverter.getMemRefAddressSpace(memRefType); if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace) allocatedPtr = rewriter.create( @@ -114,10 +115,10 @@ layout = &analysis->getAbove(op); } Type elementType = memRefType.getElementType(); - if (auto memRefElementType = elementType.dyn_cast()) + if (auto memRefElementType = llvm::dyn_cast(elementType)) return getTypeConverter()->getMemRefDescriptorSize(memRefElementType, *layout); - if (auto memRefElementType = elementType.dyn_cast()) + if (auto memRefElementType = llvm::dyn_cast(elementType)) return getTypeConverter()->getUnrankedMemRefDescriptorSize( memRefElementType, *layout); return layout->getTypeSize(elementType); diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -184,10 +184,10 @@ rewriter.setInsertionPointToEnd(currentBlock); Value src = op.getSource(); - auto srcType = src.getType().dyn_cast(); + auto srcType = llvm::dyn_cast(src.getType()); Value srcNumElements = computeNumElements( srcType, [&]() -> Value { return desc.size(rewriter, loc, 0); }); - auto dstType = op.getType().cast(); + auto dstType = llvm::cast(op.getType()); Value dstNumElements = computeNumElements( dstType, [&]() -> Value { return op.getDynamicResultSize(); }); Value cond = rewriter.create( @@ -342,7 +342,7 @@ unsigned alignment = op.getAlignment(); auto loc = op.getLoc(); - auto srcMemRefType = op.getMemref().getType().cast(); + auto srcMemRefType = llvm::cast(op.getMemref().getType()); Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{}, rewriter); @@ -417,7 +417,7 @@ matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type operandType = dimOp.getSource().getType(); - if (operandType.isa()) { + if (llvm::isa(operandType)) { FailureOr extractedSize = extractSizeOfUnrankedMemRef( operandType, dimOp, adaptor.getOperands(), rewriter); if (failed(extractedSize)) @@ -425,7 +425,7 @@ rewriter.replaceOp(dimOp, {*extractedSize}); return success(); } - if (operandType.isa()) { + if (llvm::isa(operandType)) { rewriter.replaceOp( dimOp, {extractSizeOfRankedMemRef(operandType, dimOp, adaptor.getOperands(), rewriter)}); @@ -441,7 +441,7 @@ ConversionPatternRewriter &rewriter) const { Location loc = dimOp.getLoc(); - auto unrankedMemRefType = operandType.cast(); + auto unrankedMemRefType = llvm::cast(operandType); auto scalarMemRefType = MemRefType::get({}, unrankedMemRefType.getElementType()); FailureOr maybeAddressSpace = @@ -492,8 +492,7 @@ return idx; if (auto constantOp = dimOp.getIndex().getDefiningOp()) - return constantOp.getValue() - .cast() + return llvm::cast(constantOp.getValue()) .getValue() .getSExtValue(); @@ -506,7 +505,7 @@ Location loc = dimOp.getLoc(); // Take advantage if index is constant. - MemRefType memRefType = operandType.cast(); + MemRefType memRefType = llvm::cast(operandType); if (std::optional index = getConstantDimIndex(dimOp)) { int64_t i = *index; if (i >= 0 && i < memRefType.getRank()) { @@ -589,7 +588,7 @@ // Compute the loaded value and branch to the loop block. rewriter.setInsertionPointToEnd(initBlock); - auto memRefType = atomicOp.getMemref().getType().cast(); + auto memRefType = llvm::cast(atomicOp.getMemref().getType()); auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(), adaptor.getIndices(), rewriter); Value init = rewriter.create( @@ -712,7 +711,7 @@ Location loc, Value sizeBytes, Operation *op) const override { auto getGlobalOp = cast(op); - MemRefType type = getGlobalOp.getResult().getType().cast(); + MemRefType type = llvm::cast(getGlobalOp.getResult().getType()); // This is called after a type conversion, which would have failed if this // call fails. @@ -823,12 +822,13 @@ ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Type operandType = op.getMemref().getType(); - if (auto unrankedMemRefType = operandType.dyn_cast()) { + if (auto unrankedMemRefType = + llvm::dyn_cast(operandType)) { UnrankedMemRefDescriptor desc(adaptor.getMemref()); rewriter.replaceOp(op, {desc.rank(rewriter, loc)}); return success(); } - if (auto rankedMemRefType = operandType.dyn_cast()) { + if (auto rankedMemRefType = llvm::dyn_cast(operandType)) { rewriter.replaceOp( op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())}); return success(); @@ -849,17 +849,17 @@ // and require source and result type to have the same rank. Therefore, // perform a sanity check that the underlying structs are the same. Once op // semantics are relaxed we can revisit. - if (srcType.isa() && dstType.isa()) + if (llvm::isa(srcType) && llvm::isa(dstType)) return success(typeConverter->convertType(srcType) == typeConverter->convertType(dstType)); // At least one of the operands is unranked type - assert(srcType.isa() || - dstType.isa()); + assert(llvm::isa(srcType) || + llvm::isa(dstType)); // Unranked to unranked cast is disallowed - return !(srcType.isa() && - dstType.isa()) + return !(llvm::isa(srcType) && + llvm::isa(dstType)) ? success() : failure(); } @@ -872,15 +872,16 @@ auto loc = memRefCastOp.getLoc(); // For ranked/ranked case, just keep the original descriptor. - if (srcType.isa() && dstType.isa()) + if (llvm::isa(srcType) && llvm::isa(dstType)) return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()}); - if (srcType.isa() && dstType.isa()) { + if (llvm::isa(srcType) && + llvm::isa(dstType)) { // Casting ranked to unranked memref type // Set the rank in the destination from the memref type // Allocate space on the stack and copy the src memref descriptor // Set the ptr in the destination to the stack space - auto srcMemRefType = srcType.cast(); + auto srcMemRefType = llvm::cast(srcType); int64_t rank = srcMemRefType.getRank(); // ptr = AllocaOp sizeof(MemRefDescriptor) auto ptr = getTypeConverter()->promoteOneMemRefDescriptor( @@ -905,7 +906,8 @@ memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr); rewriter.replaceOp(memRefCastOp, (Value)memRefDesc); - } else if (srcType.isa() && dstType.isa()) { + } else if (llvm::isa(srcType) && + llvm::isa(dstType)) { // Casting from unranked type to ranked. // The operation is assumed to be doing a correct cast. If the destination // type mismatches the unranked the type, it is undefined behavior. @@ -942,7 +944,7 @@ lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto loc = op.getLoc(); - auto srcType = op.getSource().getType().dyn_cast(); + auto srcType = llvm::dyn_cast(op.getSource().getType()); MemRefDescriptor srcDesc(adaptor.getSource()); @@ -984,8 +986,8 @@ lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto loc = op.getLoc(); - auto srcType = op.getSource().getType().cast(); - auto targetType = op.getTarget().getType().cast(); + auto srcType = llvm::cast(op.getSource().getType()); + auto targetType = llvm::cast(op.getTarget().getType()); // First make sure we have an unranked memref descriptor representation. auto makeUnranked = [&, this](Value ranked, MemRefType type) { @@ -1012,11 +1014,11 @@ auto stackSaveOp = rewriter.create(loc, getVoidPtrType()); - auto srcMemRefType = srcType.dyn_cast(); + auto srcMemRefType = llvm::dyn_cast(srcType); Value unrankedSource = srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType) : adaptor.getSource(); - auto targetMemRefType = targetType.dyn_cast(); + auto targetMemRefType = llvm::dyn_cast(targetType); Value unrankedTarget = targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType) : adaptor.getTarget(); @@ -1055,8 +1057,8 @@ LogicalResult matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto srcType = op.getSource().getType().cast(); - auto targetType = op.getTarget().getType().cast(); + auto srcType = llvm::cast(op.getSource().getType()); + auto targetType = llvm::cast(op.getTarget().getType()); auto isStaticShapeAndContiguousRowMajor = [](MemRefType type) { if (!type.hasStaticShape()) @@ -1077,7 +1079,7 @@ }; auto isContiguousMemrefType = [&](BaseMemRefType type) { - auto memrefType = type.dyn_cast(); + auto memrefType = llvm::dyn_cast(type); // We can use memcpy for memrefs if they have an identity layout or are // contiguous with an arbitrary offset. Ignore empty memrefs, which is a // special case handled by memrefCopy. @@ -1105,9 +1107,9 @@ Location loc = op.getLoc(); Type resultType = op.getDest().getType(); - if (auto resultTypeR = resultType.dyn_cast()) { - auto resultDescType = - typeConverter->convertType(resultTypeR).cast(); + if (auto resultTypeR = llvm::dyn_cast(resultType)) { + auto resultDescType = llvm::cast( + typeConverter->convertType(resultTypeR)); Type newPtrType = resultDescType.getBody()[0]; SmallVector descVals; @@ -1122,10 +1124,11 @@ rewriter.replaceOp(op, result); return success(); } - if (auto resultTypeU = resultType.dyn_cast()) { + if (auto resultTypeU = llvm::dyn_cast(resultType)) { // Since the type converter won't be doing this for us, get the address // space. - auto sourceType = op.getSource().getType().cast(); + auto sourceType = + llvm::cast(op.getSource().getType()); FailureOr maybeSourceAddrSpace = getTypeConverter()->getMemRefAddressSpace(sourceType); if (failed(maybeSourceAddrSpace)) @@ -1217,7 +1220,7 @@ Value *allocatedPtr, Value *alignedPtr, Value *offset = nullptr) { Type operandType = originalOperand.getType(); - if (operandType.isa()) { + if (llvm::isa(operandType)) { MemRefDescriptor desc(convertedOperand); *allocatedPtr = desc.allocatedPtr(rewriter, loc); *alignedPtr = desc.alignedPtr(rewriter, loc); @@ -1228,8 +1231,9 @@ // These will all cause assert()s on unconvertible types. unsigned memorySpace = *typeConverter.getMemRefAddressSpace( - operandType.cast()); - Type elementType = operandType.cast().getElementType(); + llvm::cast(operandType)); + Type elementType = + llvm::cast(operandType).getElementType(); Type llvmElementType = typeConverter.convertType(elementType); LLVM::LLVMPointerType elementPtrType = typeConverter.getPointerType(llvmElementType, memorySpace); @@ -1273,9 +1277,9 @@ memref::ReinterpretCastOp castOp, memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const { MemRefType targetMemRefType = - castOp.getResult().getType().cast(); - auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType) - .dyn_cast_or_null(); + llvm::cast(castOp.getResult().getType()); + auto llvmTargetDescriptorTy = llvm::dyn_cast_or_null( + typeConverter->convertType(targetMemRefType)); if (!llvmTargetDescriptorTy) return failure(); @@ -1339,13 +1343,14 @@ Type srcType, memref::ReshapeOp reshapeOp, memref::ReshapeOp::Adaptor adaptor, Value *descriptor) const { - auto shapeMemRefType = reshapeOp.getShape().getType().cast(); + auto shapeMemRefType = + llvm::cast(reshapeOp.getShape().getType()); if (shapeMemRefType.hasStaticShape()) { MemRefType targetMemRefType = - reshapeOp.getResult().getType().cast(); + llvm::cast(reshapeOp.getResult().getType()); auto llvmTargetDescriptorTy = - typeConverter->convertType(targetMemRefType) - .dyn_cast_or_null(); + llvm::dyn_cast_or_null( + typeConverter->convertType(targetMemRefType)); if (!llvmTargetDescriptorTy) return failure(); @@ -1427,7 +1432,7 @@ // Extract address space and element type. auto targetType = - reshapeOp.getResult().getType().cast(); + llvm::cast(reshapeOp.getResult().getType()); unsigned addressSpace = *getTypeConverter()->getMemRefAddressSpace(targetType); Type elementType = targetType.getElementType(); @@ -1695,7 +1700,7 @@ // Field 1: Copy the allocated pointer, used for malloc/free. Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc); - auto srcMemRefType = viewOp.getSource().getType().cast(); + auto srcMemRefType = llvm::cast(viewOp.getSource().getType()); unsigned sourceMemorySpace = *getTypeConverter()->getMemRefAddressSpace(srcMemRefType); Value bitcastPtr; @@ -1848,7 +1853,7 @@ Location loc = extractStridedMetadataOp.getLoc(); Value source = extractStridedMetadataOp.getSource(); - auto sourceMemRefType = source.getType().cast(); + auto sourceMemRefType = llvm::cast(source.getType()); int64_t rank = sourceMemRefType.getRank(); SmallVector results; results.reserve(2 + rank * 2); @@ -1858,7 +1863,8 @@ Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc); MemRefDescriptor dstMemRef = MemRefDescriptor::fromStaticShape( rewriter, loc, *getTypeConverter(), - extractStridedMetadataOp.getBaseBuffer().getType().cast(), + llvm::cast( + extractStridedMetadataOp.getBaseBuffer().getType()), baseBuffer, alignedBuffer); results.push_back((Value)dstMemRef); diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -93,11 +93,13 @@ /// can be lowered to SPIR-V. static bool isAllocationSupported(Operation *allocOp, MemRefType type) { if (isa(allocOp)) { - auto sc = type.getMemorySpace().dyn_cast_or_null(); + auto sc = + llvm::dyn_cast_or_null(type.getMemorySpace()); if (!sc || sc.getValue() != spirv::StorageClass::Workgroup) return false; } else if (isa(allocOp)) { - auto sc = type.getMemorySpace().dyn_cast_or_null(); + auto sc = + llvm::dyn_cast_or_null(type.getMemorySpace()); if (!sc || sc.getValue() != spirv::StorageClass::Function) return false; } else { @@ -110,7 +112,7 @@ return false; Type elementType = type.getElementType(); - if (auto vecType = elementType.dyn_cast()) + if (auto vecType = llvm::dyn_cast(elementType)) elementType = vecType.getElementType(); return elementType.isIntOrFloat(); } @@ -119,7 +121,8 @@ /// operations of unsupported integer bitwidths, based on the memref /// type. Returns std::nullopt on failure. static std::optional getAtomicOpScope(MemRefType type) { - auto sc = type.getMemorySpace().dyn_cast_or_null(); + auto sc = + llvm::dyn_cast_or_null(type.getMemorySpace()); switch (sc.getValue()) { case spirv::StorageClass::StorageBuffer: return spirv::Scope::Device; @@ -324,11 +327,11 @@ AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - if (atomicOp.getType().isa()) + if (llvm::isa(atomicOp.getType())) return rewriter.notifyMatchFailure(atomicOp, "unimplemented floating-point case"); - auto memrefType = atomicOp.getMemref().getType().cast(); + auto memrefType = llvm::cast(atomicOp.getMemref().getType()); std::optional scope = getAtomicOpScope(memrefType); if (!scope) return rewriter.notifyMatchFailure(atomicOp, @@ -380,7 +383,8 @@ DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - MemRefType deallocType = operation.getMemref().getType().cast(); + MemRefType deallocType = + llvm::cast(operation.getMemref().getType()); if (!isAllocationSupported(operation, deallocType)) return rewriter.notifyMatchFailure(operation, "unhandled allocation type"); rewriter.eraseOp(operation); @@ -395,7 +399,7 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto loc = loadOp.getLoc(); - auto memrefType = loadOp.getMemref().getType().cast(); + auto memrefType = llvm::cast(loadOp.getMemref().getType()); if (!memrefType.getElementType().isSignlessInteger()) return failure(); @@ -419,18 +423,19 @@ Type pointeeType = pointerType.getPointeeType(); Type dstType; if (typeConverter.allows(spirv::Capability::Kernel)) { - if (auto arrayType = pointeeType.dyn_cast()) + if (auto arrayType = llvm::dyn_cast(pointeeType)) dstType = arrayType.getElementType(); else dstType = pointeeType; } else { // For Vulkan we need to extract element from wrapping struct and array. Type structElemType = - pointeeType.cast().getElementType(0); - if (auto arrayType = structElemType.dyn_cast()) + llvm::cast(pointeeType).getElementType(0); + if (auto arrayType = llvm::dyn_cast(structElemType)) dstType = arrayType.getElementType(); else - dstType = structElemType.cast().getElementType(); + dstType = + llvm::cast(structElemType).getElementType(); } int dstBits = dstType.getIntOrFloatBitWidth(); assert(dstBits % srcBits == 0); @@ -509,7 +514,7 @@ LogicalResult LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - auto memrefType = loadOp.getMemref().getType().cast(); + auto memrefType = llvm::cast(loadOp.getMemref().getType()); if (memrefType.getElementType().isSignlessInteger()) return failure(); auto loadPtr = spirv::getElementPtr( @@ -526,7 +531,7 @@ LogicalResult IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - auto memrefType = storeOp.getMemref().getType().cast(); + auto memrefType = llvm::cast(storeOp.getMemref().getType()); if (!memrefType.getElementType().isSignlessInteger()) return failure(); @@ -553,18 +558,19 @@ Type pointeeType = pointerType.getPointeeType(); Type dstType; if (typeConverter.allows(spirv::Capability::Kernel)) { - if (auto arrayType = pointeeType.dyn_cast()) + if (auto arrayType = llvm::dyn_cast(pointeeType)) dstType = arrayType.getElementType(); else dstType = pointeeType; } else { // For Vulkan we need to extract element from wrapping struct and array. Type structElemType = - pointeeType.cast().getElementType(0); - if (auto arrayType = structElemType.dyn_cast()) + llvm::cast(pointeeType).getElementType(0); + if (auto arrayType = llvm::dyn_cast(structElemType)) dstType = arrayType.getElementType(); else - dstType = structElemType.cast().getElementType(); + dstType = + llvm::cast(structElemType).getElementType(); } int dstBits = dstType.getIntOrFloatBitWidth(); @@ -651,21 +657,22 @@ return rewriter.notifyMatchFailure( loc, "address space casts require kernel capability"); - auto sourceType = addrCastOp.getSource().getType().dyn_cast(); + auto sourceType = + llvm::dyn_cast(addrCastOp.getSource().getType()); if (!sourceType) return rewriter.notifyMatchFailure( loc, "SPIR-V lowering requires ranked memref types"); - auto resultType = addrCastOp.getResult().getType().cast(); + auto resultType = llvm::cast(addrCastOp.getResult().getType()); - auto sourceStorageClassAttr = - sourceType.getMemorySpace().dyn_cast_or_null(); + auto sourceStorageClassAttr = llvm::dyn_cast_or_null( + sourceType.getMemorySpace()); if (!sourceStorageClassAttr) return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &diag) { diag << "source address space " << sourceType.getMemorySpace() << " must be a SPIR-V storage class"; }); - auto resultStorageClassAttr = - resultType.getMemorySpace().dyn_cast_or_null(); + auto resultStorageClassAttr = llvm::dyn_cast_or_null( + resultType.getMemorySpace()); if (!resultStorageClassAttr) return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &diag) { diag << "result address space " << resultType.getMemorySpace() @@ -709,7 +716,7 @@ LogicalResult StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - auto memrefType = storeOp.getMemref().getType().cast(); + auto memrefType = llvm::cast(storeOp.getMemref().getType()); if (memrefType.getElementType().isSignlessInteger()) return failure(); auto storePtr = spirv::getElementPtr( diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -28,7 +28,7 @@ /// `gpu.mma.sync` operation. static Type inferIntrinsicResultType(Type vectorResultType) { MLIRContext *ctx = vectorResultType.getContext(); - auto a = vectorResultType.cast(); + auto a = llvm::cast(vectorResultType); auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2); auto i32Ty = IntegerType::get(ctx, 32); auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2); @@ -69,8 +69,8 @@ Type resultType, Value intrinsicResult, RewriterBase &rewriter) { MLIRContext *ctx = rewriter.getContext(); - auto structType = intrinsicResultType.dyn_cast(); - auto arrayType = resultType.dyn_cast(); + auto structType = llvm::dyn_cast(intrinsicResultType); + auto arrayType = llvm::dyn_cast(resultType); Type i32Ty = rewriter.getI32Type(); Type f32Ty = rewriter.getF32Type(); Type f64Ty = rewriter.getF64Type(); @@ -153,7 +153,7 @@ Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4); Type i4x8Ty = LLVM::getFixedVectorType(i4Ty, 8); Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1); - auto arrayTy = operand.getType().cast(); + auto arrayTy = llvm::cast(operand.getType()); for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) { Value toUse = rewriter.create(loc, operand, i); @@ -172,7 +172,8 @@ // For some element types (i32, f32, f64), we need to unpack the inner // vector/array type as well because the intrinsic expects individual // scalars to be provided. - VectorType innerArrayTy = arrayTy.getElementType().dyn_cast(); + VectorType innerArrayTy = + llvm::dyn_cast(arrayTy.getElementType()); if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty || innerArrayTy.getElementType() == f64Ty || innerArrayTy.getElementType() == f32Ty)) { @@ -207,7 +208,7 @@ // of shape (NumRegisters, VectorRegister) where VectorRegister is the // vector type of the result and always 32 bits long. We bitcast the result // of the NVVM::LdMatrix to this vector type. - auto vectorResultType = op->getResultTypes()[0].dyn_cast(); + auto vectorResultType = llvm::dyn_cast(op->getResultTypes()[0]); if (!vectorResultType) { return failure(); } @@ -224,7 +225,7 @@ ldMatrixResultType = rewriter.getI32Type(); } - auto srcMemrefType = op.getSrcMemref().getType().cast(); + auto srcMemrefType = llvm::cast(op.getSrcMemref().getType()); Value srcPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrcMemref(), adaptor.getIndices(), rewriter); @@ -307,7 +308,7 @@ // TODO: add an attribute to the op to customize this behavior. std::optional overflow(std::nullopt); - if (aType.getElementType().isa()) + if (llvm::isa(aType.getElementType())) overflow = NVVM::MMAIntOverflow::satfinite; SmallVector matA = @@ -388,7 +389,8 @@ // constant. auto dstByteConstOp = dyn_cast(dstBytes.getDefiningOp()); - auto dstByteAttr = dstByteConstOp.getValue().dyn_cast(); + auto dstByteAttr = + llvm::dyn_cast(dstByteConstOp.getValue()); int64_t dstByteVal = dstByteAttr.getValue().getSExtValue(); assert((dstByteVal == 4 || dstByteVal == 8 || dstByteVal == 16) && @@ -537,7 +539,7 @@ // TODO: add an attribute to the op to customize this behavior. std::optional overflow(std::nullopt); - if (aType.getElementType().isa()) + if (llvm::isa(aType.getElementType())) overflow = NVVM::MMAIntOverflow::satfinite; SmallVector matA = @@ -585,7 +587,7 @@ matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); - auto dstMemrefType = op.getDst().getType().cast(); + auto dstMemrefType = llvm::cast(op.getDst().getType()); Value dstPtr = getStridedElementPtr(loc, dstMemrefType, adaptor.getDst(), adaptor.getDstIndices(), rewriter); auto i8Ty = IntegerType::get(op.getContext(), 8); @@ -599,7 +601,7 @@ if (!getTypeConverter()->useOpaquePointers()) dstPtr = rewriter.create(loc, dstPointerType, dstPtr); - auto srcMemrefType = op.getSrc().getType().cast(); + auto srcMemrefType = llvm::cast(op.getSrc().getType()); FailureOr srcAddressSpace = getTypeConverter()->getMemRefAddressSpace(srcMemrefType); if (failed(srcAddressSpace)) diff --git a/mlir/lib/Conversion/OpenACCToLLVM/OpenACCToLLVM.cpp b/mlir/lib/Conversion/OpenACCToLLVM/OpenACCToLLVM.cpp --- a/mlir/lib/Conversion/OpenACCToLLVM/OpenACCToLLVM.cpp +++ b/mlir/lib/Conversion/OpenACCToLLVM/OpenACCToLLVM.cpp @@ -44,14 +44,15 @@ /// Check whether the type is a valid data descriptor. bool DataDescriptor::isValid(Value descriptor) { - if (auto type = descriptor.getType().dyn_cast()) { + if (auto type = llvm::dyn_cast(descriptor.getType())) { if (type.isIdentified() && type.getName().startswith(getStructName()) && type.getBody().size() == 3 && - (type.getBody()[kPtrBasePosInDataDescriptor] - .isa() || - type.getBody()[kPtrBasePosInDataDescriptor] - .isa()) && - type.getBody()[kPtrPosInDataDescriptor].isa() && + (llvm::isa( + type.getBody()[kPtrBasePosInDataDescriptor]) || + llvm::isa( + type.getBody()[kPtrBasePosInDataDescriptor])) && + llvm::isa( + type.getBody()[kPtrPosInDataDescriptor]) && type.getBody()[kSizePosInDataDescriptor].isInteger(64)) return true; } @@ -104,7 +105,7 @@ // Traverse operands that were converted to MemRefDescriptors. if (auto memRefType = - originalDataOperand.getType().dyn_cast()) { + llvm::dyn_cast(originalDataOperand.getType())) { Type structType = converter->convertType(memRefType); Value memRefDescriptor = builder .create( @@ -127,7 +128,8 @@ descr.setPointer(builder, loc, dataPtr); descr.setSize(builder, loc, sizeBytes); convertedOperands.push_back(descr); - } else if (originalDataOperand.getType().isa()) { + } else if (llvm::isa( + originalDataOperand.getType())) { convertedOperands.push_back(originalDataOperand); } else { // Type not supported. @@ -185,7 +187,7 @@ auto allDataOperandsAreConverted = [](ValueRange operands) { for (Value operand : operands) { if (!DataDescriptor::isValid(operand) && - !operand.getType().isa()) + !llvm::isa(operand.getType())) return false; } return true; diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -70,7 +70,7 @@ Value originalVariableOperand = curOp.getVariableOperand(idx); if (!originalVariableOperand) return failure(); - if (originalVariableOperand.getType().isa()) { + if (llvm::isa(originalVariableOperand.getType())) { // TODO: Support memref type in variable operands return rewriter.notifyMatchFailure(curOp, "memref is not supported yet"); @@ -101,7 +101,7 @@ Value originalVariableOperand = curOp.getVariableOperand(idx); if (!originalVariableOperand) return failure(); - if (originalVariableOperand.getType().isa()) { + if (llvm::isa(originalVariableOperand.getType())) { // TODO: Support memref type in variable operands return rewriter.notifyMatchFailure(curOp, "memref is not supported yet"); @@ -143,7 +143,7 @@ LogicalResult matchAndRewrite(omp::ReductionOp curOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (curOp.getAccumulator().getType().isa()) { + if (llvm::isa(curOp.getAccumulator().getType())) { // TODO: Support memref type in variable operands return rewriter.notifyMatchFailure(curOp, "memref is not supported yet"); } diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp --- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp +++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp @@ -441,7 +441,7 @@ Value iv, lowerBound, upperBound, step; std::tie(mappingAttribute, iv, lowerBound, upperBound, step) = config; auto annotation = - mappingAttribute.dyn_cast(); + llvm::dyn_cast(mappingAttribute); if (!annotation) return parallelOp.emitOpError() << "expected mapping attribute for lowering to GPU"; diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -51,7 +51,7 @@ Value reducedVal = matchReduction({block.getArguments()[1]}, /*redPos=*/0, combinerOps); - if (!reducedVal || !reducedVal.isa() || + if (!reducedVal || !llvm::isa(reducedVal) || combinerOps.size() != 1) return false; @@ -155,7 +155,7 @@ /// Returns an attribute with the minimum (if `min` is set) or the maximum value /// (otherwise) for the given float type. static Attribute minMaxValueForFloat(Type type, bool min) { - auto fltType = type.cast(); + auto fltType = llvm::cast(type); return FloatAttr::get( type, llvm::APFloat::getLargest(fltSemanticsForType(fltType), min)); } @@ -164,7 +164,7 @@ /// the maximum value (otherwise) for the given integer type, regardless of its /// signedness semantics (only the width is considered). static Attribute minMaxValueForSignedInt(Type type, bool min) { - auto intType = type.cast(); + auto intType = llvm::cast(type); unsigned bitwidth = intType.getWidth(); return IntegerAttr::get(type, min ? llvm::APInt::getSignedMinValue(bitwidth) : llvm::APInt::getSignedMaxValue(bitwidth)); @@ -174,7 +174,7 @@ /// the maximum value (otherwise) for the given integer type, regardless of its /// signedness semantics (only the width is considered). static Attribute minMaxValueForUnsignedInt(Type type, bool min) { - auto intType = type.cast(); + auto intType = llvm::cast(type); unsigned bitwidth = intType.getWidth(); return IntegerAttr::get(type, min ? llvm::APInt::getZero(bitwidth) : llvm::APInt::getAllOnes(bitwidth)); @@ -388,7 +388,7 @@ reductionVariables.reserve(parallelOp.getNumReductions()); for (Value init : parallelOp.getInitVals()) { assert((LLVM::isCompatibleType(init.getType()) || - init.getType().isa()) && + llvm::isa(init.getType())) && "cannot create a reduction variable if the type is not an LLVM " "pointer element"); Value storage = rewriter.create( diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp @@ -220,9 +220,8 @@ auto kernelOperands = adaptor.getOperands().take_back(numKernelOperands); for (const auto &operand : llvm::enumerate(kernelOperands)) { // Check if the kernel's operand is a ranked memref. - auto memRefType = launchOp.getKernelOperand(operand.index()) - .getType() - .dyn_cast(); + auto memRefType = llvm::dyn_cast( + launchOp.getKernelOperand(operand.index()).getType()); if (!memRefType) return failure(); @@ -240,8 +239,8 @@ // the kernel operand. Construct its new name and create a corresponding // LLVM dialect global variable. spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()]; - auto pointeeType = - spirvGlobal.getType().cast().getPointeeType(); + auto pointeeType = llvm::cast(spirvGlobal.getType()) + .getPointeeType(); auto dstGlobalType = typeConverter->convertType(pointeeType); if (!dstGlobalType) return failure(); diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -37,7 +37,7 @@ static bool isSignedIntegerOrVector(Type type) { if (type.isSignedInteger()) return true; - if (auto vecType = type.dyn_cast()) + if (auto vecType = llvm::dyn_cast(type)) return vecType.getElementType().isSignedInteger(); return false; } @@ -46,18 +46,18 @@ static bool isUnsignedIntegerOrVector(Type type) { if (type.isUnsignedInteger()) return true; - if (auto vecType = type.dyn_cast()) + if (auto vecType = llvm::dyn_cast(type)) return vecType.getElementType().isUnsignedInteger(); return false; } /// Returns the bit width of integer, float or vector of float or integer values static unsigned getBitWidth(Type type) { - assert((type.isIntOrFloat() || type.isa()) && + assert((type.isIntOrFloat() || llvm::isa(type)) && "bitwidth is not supported for this type"); if (type.isIntOrFloat()) return type.getIntOrFloatBitWidth(); - auto vecType = type.dyn_cast(); + auto vecType = llvm::dyn_cast(type); auto elementType = vecType.getElementType(); assert(elementType.isIntOrFloat() && "only integers and floats have a bitwidth"); @@ -66,29 +66,29 @@ /// Returns the bit width of LLVMType integer or vector. static unsigned getLLVMTypeBitWidth(Type type) { - return (LLVM::isCompatibleVectorType(type) ? LLVM::getVectorElementType(type) - : type) - .cast() + return llvm::cast((LLVM::isCompatibleVectorType(type) + ? LLVM::getVectorElementType(type) + : type)) .getWidth(); } /// Creates `IntegerAttribute` with all bits set for given type static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) { - if (auto vecType = type.dyn_cast()) { - auto integerType = vecType.getElementType().cast(); + if (auto vecType = llvm::dyn_cast(type)) { + auto integerType = llvm::cast(vecType.getElementType()); return builder.getIntegerAttr(integerType, -1); } - auto integerType = type.cast(); + auto integerType = llvm::cast(type); return builder.getIntegerAttr(integerType, -1); } /// Creates `llvm.mlir.constant` with all bits set for the given type. static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter) { - if (srcType.isa()) { + if (llvm::isa(srcType)) { return rewriter.create( loc, dstType, - SplatElementsAttr::get(srcType.cast(), + SplatElementsAttr::get(llvm::cast(srcType), minusOneIntegerAttribute(srcType, rewriter))); } return rewriter.create( @@ -98,14 +98,14 @@ /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value. static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value) { - if (auto vecType = srcType.dyn_cast()) { - auto floatType = vecType.getElementType().cast(); + if (auto vecType = llvm::dyn_cast(srcType)) { + auto floatType = llvm::cast(vecType.getElementType()); return rewriter.create( loc, dstType, SplatElementsAttr::get(vecType, rewriter.getFloatAttr(floatType, value))); } - auto floatType = srcType.cast(); + auto floatType = llvm::cast(srcType); return rewriter.create( loc, dstType, rewriter.getFloatAttr(floatType, value)); } @@ -157,7 +157,7 @@ static Value optionallyBroadcast(Location loc, Value value, Type srcType, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { - if (auto vectorType = srcType.dyn_cast()) { + if (auto vectorType = llvm::dyn_cast(srcType)) { unsigned numElements = vectorType.getNumElements(); return broadcast(loc, value, numElements, typeConverter, rewriter); } @@ -251,7 +251,7 @@ TypeConverter &converter) { unsigned stride = type.getArrayStride(); Type elementType = type.getElementType(); - auto sizeInBytes = elementType.cast().getSizeInBytes(); + auto sizeInBytes = llvm::cast(elementType).getSizeInBytes(); if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride)) return std::nullopt; @@ -319,10 +319,9 @@ indices.insert(indices.begin(), zero); rewriter.replaceOpWithNewOp( op, dstType, - typeConverter.convertType(op.getBasePtr() - .getType() - .cast() - .getPointeeType()), + typeConverter.convertType( + llvm::cast(op.getBasePtr().getType()) + .getPointeeType()), adaptor.getBasePtr(), indices); return success(); } @@ -397,7 +396,7 @@ matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto srcType = constOp.getType(); - if (!srcType.isa() && !srcType.isIntOrFloat()) + if (!llvm::isa(srcType) && !srcType.isIntOrFloat()) return failure(); auto dstType = typeConverter.convertType(srcType); @@ -413,15 +412,16 @@ isUnsignedIntegerOrVector(srcType)) { auto signlessType = rewriter.getIntegerType(getBitWidth(srcType)); - if (srcType.isa()) { - auto dstElementsAttr = constOp.getValue().cast(); + if (llvm::isa(srcType)) { + auto dstElementsAttr = + llvm::cast(constOp.getValue()); rewriter.replaceOpWithNewOp( constOp, dstType, dstElementsAttr.mapValues( signlessType, [&](const APInt &value) { return value; })); return success(); } - auto srcAttr = constOp.getValue().cast(); + auto srcAttr = llvm::cast(constOp.getValue()); auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue()); rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr); return success(); @@ -454,17 +454,18 @@ // Create a constant that holds the size of the `Base`. IntegerType integerType; - if (auto vecType = srcType.dyn_cast()) - integerType = vecType.getElementType().cast(); + if (auto vecType = llvm::dyn_cast(srcType)) + integerType = llvm::cast(vecType.getElementType()); else - integerType = srcType.cast(); + integerType = llvm::cast(srcType); auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType)); Value size = - srcType.isa() + llvm::isa(srcType) ? rewriter.create( loc, dstType, - SplatElementsAttr::get(srcType.cast(), baseSize)) + SplatElementsAttr::get(llvm::cast(srcType), + baseSize)) : rewriter.create(loc, dstType, baseSize); // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit @@ -573,9 +574,9 @@ return failure(); Type containerType = op.getComposite().getType(); - if (containerType.isa()) { + if (llvm::isa(containerType)) { Location loc = op.getLoc(); - IntegerAttr value = op.getIndices()[0].cast(); + IntegerAttr value = llvm::cast(op.getIndices()[0]); Value index = createI32ConstantOf(loc, rewriter, value.getInt()); rewriter.replaceOpWithNewOp( op, dstType, adaptor.getComposite(), index); @@ -605,9 +606,9 @@ return failure(); Type containerType = op.getComposite().getType(); - if (containerType.isa()) { + if (llvm::isa(containerType)) { Location loc = op.getLoc(); - IntegerAttr value = op.getIndices()[0].cast(); + IntegerAttr value = llvm::cast(op.getIndices()[0]); Value index = createI32ConstantOf(loc, rewriter, value.getInt()); rewriter.replaceOpWithNewOp( op, dstType, adaptor.getComposite(), adaptor.getObject(), index); @@ -732,7 +733,7 @@ if (op.getInitializer()) return failure(); - auto srcType = op.getType().cast(); + auto srcType = llvm::cast(op.getType()); auto dstType = typeConverter.convertType(srcType.getPointeeType()); if (!dstType) return failure(); @@ -946,11 +947,11 @@ Location loc = notOp.getLoc(); IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter); - auto mask = srcType.template isa() + auto mask = llvm::isa(srcType) ? rewriter.create( loc, dstType, SplatElementsAttr::get( - srcType.template cast(), minusOne)) + llvm::cast(srcType), minusOne)) : rewriter.create(loc, dstType, minusOne); rewriter.template replaceOpWithNewOp(notOp, dstType, notOp.getOperand(), mask); @@ -1262,9 +1263,9 @@ ConversionPatternRewriter &rewriter) const override { auto srcType = varOp.getType(); // Initialization is supported for scalars and vectors only. - auto pointerTo = srcType.cast().getPointeeType(); + auto pointerTo = llvm::cast(srcType).getPointeeType(); auto init = varOp.getInitializer(); - if (init && !pointerTo.isIntOrFloat() && !pointerTo.isa()) + if (init && !pointerTo.isIntOrFloat() && !llvm::isa(pointerTo)) return failure(); auto dstType = typeConverter.convertType(srcType); @@ -1303,7 +1304,7 @@ return failure(); if (typeConverter.useOpaquePointers() && - dstType.isa()) { + llvm::isa(dstType)) { rewriter.replaceOp(bitcastOp, adaptor.getOperand()); return success(); } @@ -1416,8 +1417,10 @@ auto components = adaptor.getComponents(); auto vector1 = adaptor.getVector1(); auto vector2 = adaptor.getVector2(); - int vector1Size = vector1.getType().cast().getNumElements(); - int vector2Size = vector2.getType().cast().getNumElements(); + int vector1Size = + llvm::cast(vector1.getType()).getNumElements(); + int vector2Size = + llvm::cast(vector2.getType()).getNumElements(); if (vector1Size == vector2Size) { rewriter.replaceOpWithNewOp( op, vector1, vector2, @@ -1426,16 +1429,16 @@ } auto dstType = typeConverter.convertType(op.getType()); - auto scalarType = dstType.cast().getElementType(); + auto scalarType = llvm::cast(dstType).getElementType(); auto componentsArray = components.getValue(); auto *context = rewriter.getContext(); auto llvmI32Type = IntegerType::get(context, 32); Value targetOp = rewriter.create(loc, dstType); for (unsigned i = 0; i < componentsArray.size(); i++) { - if (!componentsArray[i].isa()) + if (!llvm::isa(componentsArray[i])) return op.emitError("unable to support non-constant component"); - int indexVal = componentsArray[i].cast().getInt(); + int indexVal = llvm::cast(componentsArray[i]).getInt(); if (indexVal == -1) continue; diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -123,7 +123,7 @@ // constant stride. static std::optional getMemrefConstantHorizontalStride(ShapedType type) { - auto memrefType = type.dyn_cast(); + auto memrefType = llvm::dyn_cast(type); if (!memrefType) return false; // If the memref is 0 or 1D the horizontal stride is 0. @@ -193,10 +193,10 @@ /// Return true if the constant is a splat to a 2D vector so that it can be /// converted to a MMA constant matrix op. static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) { - auto vecType = constantOp.getType().dyn_cast(); + auto vecType = llvm::dyn_cast(constantOp.getType()); if (!vecType || vecType.getRank() != 2) return false; - return constantOp.getValue().isa(); + return llvm::isa(constantOp.getValue()); } /// Return true if this is a broadcast from scalar to a 2D vector. @@ -268,11 +268,11 @@ // matrixB and matrixC operands. vector.extract_strided_slice op // is not supported on registers containing matrixA operands. if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B) - return (op->getResult(0).getType().cast() == - (*contractOp).getRhs().getType().cast()); + return (llvm::cast(op->getResult(0).getType()) == + llvm::cast((*contractOp).getRhs().getType())); if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C) - return (op->getResult(0).getType().cast() == - (*contractOp).getAcc().getType().cast()); + return (llvm::cast(op->getResult(0).getType()) == + llvm::cast((*contractOp).getAcc().getType())); return false; } @@ -344,11 +344,11 @@ bool useNvGpu) { auto hasVectorDest = [](Operation *op) { return llvm::any_of(op->getResultTypes(), - [](Type t) { return t.isa(); }); + [](Type t) { return llvm::isa(t); }); }; auto hasVectorSrc = [](Operation *op) { return llvm::any_of(op->getOperandTypes(), - [](Type t) { return t.isa(); }); + [](Type t) { return llvm::isa(t); }); }; SetVector opToConvert; op->walk([&](vector::ContractionOp contract) { @@ -447,9 +447,9 @@ if ((extOp = source.getDefiningOp()) || (extOp = source.getDefiningOp())) { source = extOp->getOperand(0); - resultType = - VectorType::get(resultType.cast().getShape(), - source.getType().cast().getElementType()); + resultType = VectorType::get( + llvm::cast(resultType).getShape(), + llvm::cast(source.getType()).getElementType()); } auto transferReadOp = source.getDefiningOp(); @@ -553,7 +553,7 @@ bool isSignedExtend = isa(user); if (isSignedExtend || isa(user)) { elType = IntegerType::get( - op.getContext(), elType.cast().getWidth(), + op.getContext(), llvm::cast(elType).getWidth(), isSignedExtend ? IntegerType::Signed : IntegerType::Unsigned); mappingResult = user->getResult(0); fragType = inferFragType(user); @@ -610,7 +610,7 @@ SmallVector shape{regInfo.numRegistersPerFragment, regInfo.elementsPerRegister}; Type elType = regInfo.registerLLVMType; - if (auto vecType = elType.dyn_cast()) + if (auto vecType = llvm::dyn_cast(elType)) elType = vecType.getElementType(); return VectorType::get(shape, elType); } @@ -637,7 +637,7 @@ } VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); - auto dense = op.getValue().dyn_cast(); + auto dense = llvm::dyn_cast(op.getValue()); if (!dense) { LLVM_DEBUG(DBGS() << "not a splat\n"); return rewriter.notifyMatchFailure(op, "not a splat"); @@ -782,7 +782,7 @@ // If we are not transposing, then we can use vectorized loads. Otherwise, we // must load each element individually. if (!isTransposeLoad) { - if (!loadedElType.isa()) { + if (!llvm::isa(loadedElType)) { loadedElType = VectorType::get({1}, loadedElType); } @@ -805,7 +805,7 @@ rewriter.getI64ArrayAttr(i)); } } else { - if (auto vecType = loadedElType.dyn_cast()) { + if (auto vecType = llvm::dyn_cast(loadedElType)) { loadedElType = vecType.getElementType(); } for (int i = 0; i < vectorType.getShape()[0]; i++) { @@ -838,7 +838,7 @@ /// Return true if this is a shared memory memref type. static bool isSharedMemory(MemRefType type) { auto addressSpace = - type.getMemorySpace().dyn_cast_or_null(); + llvm::dyn_cast_or_null(type.getMemorySpace()); if (addressSpace && addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace()) return true; @@ -860,7 +860,7 @@ return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); bool isLdMatrixCompatible = - isSharedMemory(op.getSource().getType().cast()) && + isSharedMemory(llvm::cast(op.getSource().getType())) && nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128; VectorType vecTy = op.getVectorType(); @@ -929,7 +929,7 @@ static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl &results) { for (auto attr : arrayAttr) - results.push_back(attr.cast().getInt()); + results.push_back(llvm::cast(attr).getInt()); } static LogicalResult @@ -1041,9 +1041,9 @@ itC == valueMapping.end()) return rewriter.notifyMatchFailure(op, "no mapping"); Value opA = itA->second, opB = itB->second, opC = itC->second; - int64_t m = op.getLhs().getType().cast().getShape()[0]; - int64_t n = op.getRhs().getType().cast().getShape()[0]; - int64_t k = op.getLhs().getType().cast().getShape()[1]; + int64_t m = llvm::cast(op.getLhs().getType()).getShape()[0]; + int64_t n = llvm::cast(op.getRhs().getType()).getShape()[0]; + int64_t k = llvm::cast(op.getLhs().getType()).getShape()[1]; Value matmul = rewriter.create( op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k})); valueMapping[op.getResult()] = matmul; @@ -1060,11 +1060,11 @@ assert(constantSupportsMMAMatrixType(op)); auto splat = - op.getValue().cast().getSplatValue(); + llvm::cast(op.getValue()).getSplatValue(); auto scalarConstant = rewriter.create(op.getLoc(), splat.getType(), splat); const char *fragType = inferFragType(op); - auto vecType = op.getType().cast(); + auto vecType = llvm::cast(op.getType()); gpu::MMAMatrixType type = gpu::MMAMatrixType::get( vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType)); auto matrix = rewriter.create( diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -256,8 +256,8 @@ return failure(); // Resolve address. - auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType()) - .template cast(); + auto vtype = llvm::cast( + this->typeConverter->convertType(loadOrStoreOp.getVectorType())); Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(), adaptor.getIndices(), rewriter); Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype, @@ -277,7 +277,7 @@ LogicalResult matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - MemRefType memRefType = gather.getBaseType().dyn_cast(); + MemRefType memRefType = llvm::dyn_cast(gather.getBaseType()); assert(memRefType && "The base should be bufferized"); if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter()))) @@ -296,7 +296,7 @@ auto llvmNDVectorTy = adaptor.getIndexVec().getType(); // Handle the simple case of 1-D vector. - if (!llvmNDVectorTy.isa()) { + if (!llvm::isa(llvmNDVectorTy)) { auto vType = gather.getVectorType(); // Resolve address. Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), @@ -501,7 +501,7 @@ static Value createReductionNeutralValue(ReductionNeutralFPMin neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { - auto floatType = llvmType.cast(); + auto floatType = llvm::cast(llvmType); return rewriter.create( loc, llvmType, rewriter.getFloatAttr( @@ -513,7 +513,7 @@ static Value createReductionNeutralValue(ReductionNeutralFPMax neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { - auto floatType = llvmType.cast(); + auto floatType = llvm::cast(llvmType); return rewriter.create( loc, llvmType, rewriter.getFloatAttr( @@ -585,9 +585,9 @@ /// with vector types. static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs, Value rhs, bool isMin) { - auto floatType = getElementTypeOrSelf(lhs.getType()).cast(); + auto floatType = llvm::cast(getElementTypeOrSelf(lhs.getType())); Type i1Type = builder.getI1Type(); - if (auto vecType = lhs.getType().dyn_cast()) + if (auto vecType = llvm::dyn_cast(lhs.getType())) i1Type = VectorType::get(vecType.getShape(), i1Type); Value cmp = builder.create( loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt, @@ -768,7 +768,7 @@ return success(); } - if (!eltType.isa()) + if (!llvm::isa(eltType)) return failure(); // Floating-point reductions: add/mul/min/max @@ -966,14 +966,14 @@ // For all other cases, insert the individual values individually. int64_t v1Dim = v1Type.getDimSize(0); Type eltType; - if (auto arrayType = llvmType.dyn_cast()) + if (auto arrayType = llvm::dyn_cast(llvmType)) eltType = arrayType.getElementType(); else - eltType = llvmType.cast().getElementType(); + eltType = llvm::cast(llvmType).getElementType(); Value insert = rewriter.create(loc, llvmType); int64_t insPos = 0; for (const auto &en : llvm::enumerate(maskArrayAttr)) { - int64_t extPos = en.value().cast().getInt(); + int64_t extPos = llvm::cast(en.value()).getInt(); Value value = adaptor.getV1(); if (extPos >= v1Dim) { extPos -= v1Dim; @@ -1046,7 +1046,7 @@ } // One-shot extraction of vector from array (only requires extractvalue). - if (resultType.isa()) { + if (llvm::isa(resultType)) { SmallVector indices; for (auto idx : positionArrayAttr.getAsRange()) indices.push_back(idx.getInt()); @@ -1062,13 +1062,13 @@ if (positionAttrs.size() > 1) { SmallVector nMinusOnePosition; for (auto idx : positionAttrs.drop_back()) - nMinusOnePosition.push_back(idx.cast().getInt()); + nMinusOnePosition.push_back(llvm::cast(idx).getInt()); extracted = rewriter.create(loc, extracted, nMinusOnePosition); } // Remaining extraction of element from 1-D LLVM vector - auto position = positionAttrs.back().cast(); + auto position = llvm::cast(positionAttrs.back()); auto i64Type = IntegerType::get(rewriter.getContext(), 64); auto constant = rewriter.create(loc, i64Type, position); extracted = @@ -1169,7 +1169,7 @@ } // One-shot insertion of a vector into an array (only requires insertvalue). - if (sourceType.isa()) { + if (llvm::isa(sourceType)) { Value inserted = rewriter.create( loc, adaptor.getDest(), adaptor.getSource(), LLVM::convertArrayToIndices(positionArrayAttr)); @@ -1180,7 +1180,7 @@ // Potential extraction of 1-D vector from array. Value extracted = adaptor.getDest(); auto positionAttrs = positionArrayAttr.getValue(); - auto position = positionAttrs.back().cast(); + auto position = llvm::cast(positionAttrs.back()); auto oneDVectorType = destVectorType; if (positionAttrs.size() > 1) { oneDVectorType = reducedVectorTypeBack(destVectorType); @@ -1333,7 +1333,7 @@ ConversionPatternRewriter &rewriter) const override { auto loc = castOp->getLoc(); MemRefType sourceMemRefType = - castOp.getOperand().getType().cast(); + llvm::cast(castOp.getOperand().getType()); MemRefType targetMemRefType = castOp.getType(); // Only static shape casts supported atm. @@ -1341,14 +1341,14 @@ !targetMemRefType.hasStaticShape()) return failure(); - auto llvmSourceDescriptorTy = - adaptor.getOperands()[0].getType().dyn_cast(); + auto llvmSourceDescriptorTy = llvm::dyn_cast( + adaptor.getOperands()[0].getType()); if (!llvmSourceDescriptorTy) return failure(); MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]); - auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType) - .dyn_cast_or_null(); + auto llvmTargetDescriptorTy = llvm::dyn_cast_or_null( + typeConverter->convertType(targetMemRefType)); if (!llvmTargetDescriptorTy) return failure(); @@ -1418,7 +1418,7 @@ LogicalResult matchAndRewrite(vector::CreateMaskOp op, PatternRewriter &rewriter) const override { auto dstType = op.getType(); - if (dstType.getRank() != 1 || !dstType.cast().isScalable()) + if (dstType.getRank() != 1 || !llvm::cast(dstType).isScalable()) return failure(); IntegerType idxType = force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type(); @@ -1465,7 +1465,7 @@ // Make sure element type has runtime support. PrintConversion conversion = PrintConversion::None; - VectorType vectorType = printType.dyn_cast(); + VectorType vectorType = llvm::dyn_cast(printType); Type eltType = vectorType ? vectorType.getElementType() : printType; auto parent = printOp->getParentOfType(); Operation *printer; @@ -1481,7 +1481,7 @@ printer = LLVM::lookupOrCreatePrintBF16Fn(parent); } else if (eltType.isIndex()) { printer = LLVM::lookupOrCreatePrintU64Fn(parent); - } else if (auto intTy = eltType.dyn_cast()) { + } else if (auto intTy = llvm::dyn_cast(eltType)) { // Integers need a zero or sign extension on the operand // (depending on the source type) as well as a signed or // unsigned print method. Up to 64-bit is supported. @@ -1536,7 +1536,7 @@ void emitRanks(ConversionPatternRewriter &rewriter, Operation *op, Value value, Type type, Operation *printer, int64_t rank, PrintConversion conversion) const { - VectorType vectorType = type.dyn_cast(); + VectorType vectorType = llvm::dyn_cast(type); Location loc = op->getLoc(); if (!vectorType) { assert(rank == 0 && "The scalar case expects rank == 0"); @@ -1610,7 +1610,7 @@ LogicalResult matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - VectorType resultType = splatOp.getType().cast(); + VectorType resultType = llvm::cast(splatOp.getType()); if (resultType.getRank() > 1) return failure(); @@ -1633,7 +1633,7 @@ auto v = rewriter.create( splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero); - int64_t width = splatOp.getType().cast().getDimSize(0); + int64_t width = llvm::cast(splatOp.getType()).getDimSize(0); SmallVector zeroValues(width, 0); // Shuffle the value across the desired number of elements. diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -258,7 +258,7 @@ /// Return true if this transfer op operates on a source tensor. template static bool isTensorOp(OpTy xferOp) { - if (xferOp.getShapedType().template isa()) { + if (llvm::isa(xferOp.getShapedType())) { if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) { // TransferWriteOps on tensors have a result. assert(xferOp->getNumResults() > 0); @@ -314,7 +314,7 @@ /// /// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>> static MemRefType unpackOneDim(MemRefType type) { - auto vectorType = type.getElementType().dyn_cast(); + auto vectorType = llvm::dyn_cast(type.getElementType()); auto memrefShape = type.getShape(); SmallVector newMemrefShape; newMemrefShape.append(memrefShape.begin(), memrefShape.end()); @@ -408,8 +408,8 @@ getXferIndices(b, xferOp, iv, xferIndices); Location loc = xferOp.getLoc(); - auto bufferType = buffer.getType().dyn_cast(); - auto vecType = bufferType.getElementType().dyn_cast(); + auto bufferType = llvm::dyn_cast(buffer.getType()); + auto vecType = llvm::dyn_cast(bufferType.getElementType()); auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); auto newXferOp = b.create( loc, vecType, xferOp.getSource(), xferIndices, @@ -432,8 +432,8 @@ storeIndices.push_back(iv); Location loc = xferOp.getLoc(); - auto bufferType = buffer.getType().dyn_cast(); - auto vecType = bufferType.getElementType().dyn_cast(); + auto bufferType = llvm::dyn_cast(buffer.getType()); + auto vecType = llvm::dyn_cast(bufferType.getElementType()); auto vec = b.create(loc, vecType, xferOp.getPadding()); b.create(loc, vec, buffer, storeIndices); @@ -698,7 +698,7 @@ // Find and cast data buffer. How the buffer can be found depends on OpTy. ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter); auto dataBuffer = Strategy::getBuffer(xferOp); - auto dataBufferType = dataBuffer.getType().template dyn_cast(); + auto dataBufferType = llvm::dyn_cast(dataBuffer.getType()); auto castedDataType = unpackOneDim(dataBufferType); auto castedDataBuffer = locB.create(castedDataType, dataBuffer); @@ -707,8 +707,7 @@ Value castedMaskBuffer; if (xferOp.getMask()) { auto maskBuffer = getMaskBuffer(xferOp); - auto maskBufferType = - maskBuffer.getType().template dyn_cast(); + auto maskBufferType = llvm::dyn_cast(maskBuffer.getType()); if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) { // Do not unpack a dimension of the mask, if: // * To-be-unpacked transfer op dimension is a broadcast. @@ -889,7 +888,7 @@ SmallVector &indices) const { if (auto insertOp = getInsertOp(xferOp)) { for (Attribute attr : insertOp.getPosition()) - indices.push_back(attr.dyn_cast().getInt()); + indices.push_back(llvm::dyn_cast(attr).getInt()); } } @@ -908,7 +907,7 @@ auto insertOp = getInsertOp(xferOp); auto vec = getResultVector(xferOp, rewriter); - auto vecType = vec.getType().dyn_cast(); + auto vecType = llvm::dyn_cast(vec.getType()); auto xferVecType = xferOp.getVectorType(); auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(), xferVecType.getElementType()); @@ -1016,7 +1015,7 @@ SmallVector &indices) const { if (auto extractOp = getExtractOp(xferOp)) { for (Attribute attr : extractOp.getPosition()) - indices.push_back(attr.dyn_cast().getInt()); + indices.push_back(llvm::dyn_cast(attr).getInt()); } } @@ -1235,7 +1234,7 @@ if (xferOp.getTransferRank() == 0) return failure(); auto map = xferOp.getPermutationMap(); - auto memRefType = xferOp.getShapedType().template dyn_cast(); + auto memRefType = llvm::dyn_cast(xferOp.getShapedType()); if (!memRefType) return failure(); diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -48,14 +48,15 @@ //===----------------------------------------------------------------------===// template static LogicalResult verifyRawBufferOp(T &op) { - MemRefType bufferType = op.getMemref().getType().template cast(); + MemRefType bufferType = llvm::cast(op.getMemref().getType()); Attribute memorySpace = bufferType.getMemorySpace(); bool isGlobal = false; if (!memorySpace) isGlobal = true; - else if (auto intMemorySpace = memorySpace.dyn_cast()) + else if (auto intMemorySpace = llvm::dyn_cast(memorySpace)) isGlobal = intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1; - else if (auto gpuMemorySpace = memorySpace.dyn_cast()) + else if (auto gpuMemorySpace = + llvm::dyn_cast(memorySpace)) isGlobal = gpuMemorySpace.getValue() == gpu::AddressSpace::Global; if (!isGlobal) @@ -216,11 +217,11 @@ Type sourceElem = sourceType, destElem = destType; uint32_t sourceLen = 1, destLen = 1; - if (auto sourceVector = sourceType.dyn_cast()) { + if (auto sourceVector = llvm::dyn_cast(sourceType)) { sourceLen = sourceVector.getNumElements(); sourceElem = sourceVector.getElementType(); } - if (auto destVector = destType.dyn_cast()) { + if (auto destVector = llvm::dyn_cast(destType)) { destLen = destVector.getNumElements(); destElem = destVector.getElementType(); } @@ -229,7 +230,7 @@ if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ()) { int64_t sourceBLen = 1; Type sourceBElem = sourceBType; - if (auto sourceBVector = sourceBType.dyn_cast()) { + if (auto sourceBVector = llvm::dyn_cast(sourceBType)) { sourceBLen = sourceBVector.getNumElements(); sourceBElem = sourceBVector.getElementType(); } diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -38,7 +38,7 @@ /// top level of a `AffineScope` region is always a valid symbol for all /// uses in that region. bool mlir::affine::isTopLevelValue(Value value, Region *region) { - if (auto arg = value.dyn_cast()) + if (auto arg = llvm::dyn_cast(value)) return arg.getParentRegion() == region; return value.getDefiningOp()->getParentRegion() == region; } @@ -62,7 +62,7 @@ // If it's a top-level value because it's a block operand, i.e. a // function argument, check whether the value replacing it after // inlining is a valid dimension in the new region. - if (value.isa()) + if (llvm::isa(value)) return legalityCheck(mapping.lookup(value), dest); // If it's a top-level value because it's defined in the region, @@ -234,7 +234,7 @@ /// conservatively assume it is not top-level. A value of index type defined at /// the top level is always a valid symbol. bool mlir::affine::isTopLevelValue(Value value) { - if (auto arg = value.dyn_cast()) { + if (auto arg = llvm::dyn_cast(value)) { // The block owning the argument may be unlinked, e.g. when the surrounding // region has not yet been attached to an Op, at which point the parent Op // is null. @@ -273,7 +273,7 @@ // This value has to be a block argument for an op that has the // `AffineScope` trait or for an affine.for or affine.parallel. - auto *parentOp = value.cast().getOwner()->getParentOp(); + auto *parentOp = llvm::cast(value).getOwner()->getParentOp(); return parentOp && (parentOp->hasTrait() || isa(parentOp)); } @@ -296,7 +296,7 @@ if (!op) { // This value has to be a block argument for an affine.for or an // affine.parallel. - auto *parentOp = value.cast().getOwner()->getParentOp(); + auto *parentOp = llvm::cast(value).getOwner()->getParentOp(); return isa(parentOp); } @@ -334,7 +334,7 @@ // Conservatively handle remaining BlockArguments as non-valid symbols. // E.g. scf.for iterArgs. - if (dimOp.getShapedValue().template isa()) + if (llvm::isa(dimOp.getShapedValue())) return false; // The dim op is also okay if its operand memref is a view/subview whose @@ -1221,7 +1221,8 @@ // AffineDialect materializer will create invalid `arith.constant` // operations if the provided Attribute is any other kind of integer. constants.push_back(dialect->materializeConstant( - b, b.getIndexAttr(ofr.get().cast().getInt()), + b, + b.getIndexAttr(llvm::cast(ofr.get()).getInt()), b.getIndexType(), loc)); actualValues.push_back(constants.back()->getResult(0)); } @@ -1785,11 +1786,11 @@ } LogicalResult AffineDmaStartOp::verifyInvariantsImpl() { - if (!getOperand(getSrcMemRefOperandIndex()).getType().isa()) + if (!llvm::isa(getOperand(getSrcMemRefOperandIndex()).getType())) return emitOpError("expected DMA source to be of memref type"); - if (!getOperand(getDstMemRefOperandIndex()).getType().isa()) + if (!llvm::isa(getOperand(getDstMemRefOperandIndex()).getType())) return emitOpError("expected DMA destination to be of memref type"); - if (!getOperand(getTagMemRefOperandIndex()).getType().isa()) + if (!llvm::isa(getOperand(getTagMemRefOperandIndex()).getType())) return emitOpError("expected DMA tag to be of memref type"); unsigned numInputsAllMaps = getSrcMap().getNumInputs() + @@ -1888,7 +1889,7 @@ parser.resolveOperand(numElementsInfo, indexType, result.operands)) return failure(); - if (!type.isa()) + if (!llvm::isa(type)) return parser.emitError(parser.getNameLoc(), "expected tag to be of memref type"); @@ -1899,7 +1900,7 @@ } LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() { - if (!getOperand(0).getType().isa()) + if (!llvm::isa(getOperand(0).getType())) return emitOpError("expected DMA tag to be of memref type"); Region *scope = getAffineScope(*this); for (auto idx : getTagIndices()) { @@ -2073,7 +2074,7 @@ return failure(); // Parse full form - affine map followed by dim and symbol list. - if (auto affineMapAttr = boundAttr.dyn_cast()) { + if (auto affineMapAttr = llvm::dyn_cast(boundAttr)) { unsigned currentNumOperands = result.operands.size(); unsigned numDims; if (parseDimAndSymbolList(p, result.operands, numDims)) @@ -2106,7 +2107,7 @@ } // Parse custom assembly form. - if (auto integerAttr = boundAttr.dyn_cast()) { + if (auto integerAttr = llvm::dyn_cast(boundAttr)) { result.attributes.pop_back(); result.addAttribute( boundAttrStrName, @@ -2296,9 +2297,9 @@ // Compute the max or min as applicable over the results. assert(!foldedResults.empty() && "bounds should have at least one result"); - auto maxOrMin = foldedResults[0].cast().getValue(); + auto maxOrMin = llvm::cast(foldedResults[0]).getValue(); for (unsigned i = 1, e = foldedResults.size(); i < e; i++) { - auto foldedResult = foldedResults[i].cast().getValue(); + auto foldedResult = llvm::cast(foldedResults[i]).getValue(); maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult) : llvm::APIntOps::smin(maxOrMin, foldedResult); } @@ -2653,7 +2654,7 @@ } AffineForOp mlir::affine::getForInductionVarOwner(Value val) { - auto ivArg = val.dyn_cast(); + auto ivArg = llvm::dyn_cast(val); if (!ivArg || !ivArg.getOwner()) return AffineForOp(); auto *containingInst = ivArg.getOwner()->getParent()->getParentOp(); @@ -2664,7 +2665,7 @@ } AffineParallelOp mlir::affine::getAffineParallelInductionVarOwner(Value val) { - auto ivArg = val.dyn_cast(); + auto ivArg = llvm::dyn_cast(val); if (!ivArg || !ivArg.getOwner()) return nullptr; Operation *containingOp = ivArg.getOwner()->getParentOp(); @@ -3113,7 +3114,7 @@ result.addOperands(operands); if (map) result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map)); - auto memrefType = operands[0].getType().cast(); + auto memrefType = llvm::cast(operands[0].getType()); result.types.push_back(memrefType.getElementType()); } @@ -3122,14 +3123,14 @@ assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info"); result.addOperands(memref); result.addOperands(mapOperands); - auto memrefType = memref.getType().cast(); + auto memrefType = llvm::cast(memref.getType()); result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map)); result.types.push_back(memrefType.getElementType()); } void AffineLoadOp::build(OpBuilder &builder, OperationState &result, Value memref, ValueRange indices) { - auto memrefType = memref.getType().cast(); + auto memrefType = llvm::cast(memref.getType()); int64_t rank = memrefType.getRank(); // Create identity map for memrefs with at least one dimension or () -> () // for zero-dimensional memrefs. @@ -3238,11 +3239,11 @@ // Check if the global memref is a constant. auto cstAttr = - global.getConstantInitValue().dyn_cast_or_null(); + llvm::dyn_cast_or_null(global.getConstantInitValue()); if (!cstAttr) return {}; // If it's a splat constant, we can fold irrespective of indices. - if (auto splatAttr = cstAttr.dyn_cast()) + if (auto splatAttr = llvm::dyn_cast(cstAttr)) return splatAttr.getSplatValue(); // Otherwise, we can fold only if we know the indices. if (!getAffineMap().isConstant()) @@ -3271,7 +3272,7 @@ void AffineStoreOp::build(OpBuilder &builder, OperationState &result, Value valueToStore, Value memref, ValueRange indices) { - auto memrefType = memref.getType().cast(); + auto memrefType = llvm::cast(memref.getType()); int64_t rank = memrefType.getRank(); // Create identity map for memrefs with at least one dimension or () -> () // for zero-dimensional memrefs. @@ -4017,7 +4018,7 @@ // Verify reduction ops are all valid for (Attribute attr : getReductions()) { - auto intAttr = attr.dyn_cast(); + auto intAttr = llvm::dyn_cast(attr); if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt())) return emitOpError("invalid reduction attribute"); } @@ -4119,7 +4120,7 @@ p << " reduce ("; llvm::interleaveComma(getReductions(), p, [&](auto &attr) { arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind( - attr.template cast().getInt()); + llvm::cast(attr).getInt()); p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\""; }); p << ") -> (" << getResultTypes() << ")"; @@ -4429,7 +4430,7 @@ void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result, VectorType resultType, Value memref, ValueRange indices) { - auto memrefType = memref.getType().cast(); + auto memrefType = llvm::cast(memref.getType()); int64_t rank = memrefType.getRank(); // Create identity map for memrefs with at least one dimension or () -> () // for zero-dimensional memrefs. @@ -4520,7 +4521,7 @@ void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result, Value valueToStore, Value memref, ValueRange indices) { - auto memrefType = memref.getType().cast(); + auto memrefType = llvm::cast(memref.getType()); int64_t rank = memrefType.getRank(); // Create identity map for memrefs with at least one dimension or () -> () // for zero-dimensional memrefs. diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp --- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp @@ -289,7 +289,7 @@ // memref type. Call Op that returns one or more memref type results // is already taken care of, by the previous conditions. if (llvm::any_of(op.getOperandTypes(), - [&](Type t) { return t.isa(); })) { + [&](Type t) { return llvm::isa(t); })) { Node node(nextNodeId++, &op); nodes.insert({node.id, node}); } @@ -379,7 +379,7 @@ OpBuilder top(forInst->getParentRegion()); // Create new memref type based on slice bounds. auto oldMemRef = cast(srcStoreOpInst).getMemRef(); - auto oldMemRefType = oldMemRef.getType().cast(); + auto oldMemRefType = llvm::cast(oldMemRef.getType()); unsigned rank = oldMemRefType.getRank(); // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'. @@ -516,7 +516,7 @@ return WalkResult::advance(); for (Value v : op->getOperands()) // Collect memref values only. - if (v.getType().isa()) + if (llvm::isa(v.getType())) memRefValues.insert(v); return WalkResult::advance(); }); diff --git a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp --- a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp @@ -88,7 +88,7 @@ return MemRefType::Builder(oldMemRefType).setShape(newShape).setLayout({}); }; - auto oldMemRefType = oldMemRef.getType().cast(); + auto oldMemRefType = llvm::cast(oldMemRef.getType()); auto newMemRefType = doubleShape(oldMemRefType); // The double buffer is allocated right before 'forOp'. diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp --- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -1852,7 +1852,8 @@ int64_t numEltPerStride = 1; int64_t stride = 1; for (int d = bufferShape.size() - 1; d >= 1; d--) { - int64_t dimSize = region.memref.getType().cast().getDimSize(d); + int64_t dimSize = + llvm::cast(region.memref.getType()).getDimSize(d); stride *= dimSize; numEltPerStride *= bufferShape[d]; // A stride is needed only if the region has a shorter extent than the @@ -1891,7 +1892,7 @@ return ubMap.getNumInputs() == ubOperands.size(); })); - unsigned rank = memref.getType().cast().getRank(); + unsigned rank = llvm::cast(memref.getType()).getRank(); assert(lbMaps.size() == rank && "wrong number of lb maps"); assert(ubMaps.size() == rank && "wrong number of ub maps"); @@ -2003,7 +2004,7 @@ auto loc = region.loc; auto memref = region.memref; - auto memRefType = memref.getType().cast(); + auto memRefType = llvm::cast(memref.getType()); if (!memRefType.getLayout().isIdentity()) { LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n"); @@ -2276,7 +2277,7 @@ assert(false && "expected load or store op"); return false; } - auto memRefType = region->memref.getType().cast(); + auto memRefType = llvm::cast(region->memref.getType()); if (!memRefType.hasStaticShape()) return false; diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -1119,9 +1119,11 @@ ArrayRef extraIndices, AffineMap indexRemap, ArrayRef extraOperands, ArrayRef symbolOperands, bool allowNonDereferencingOps) { - unsigned newMemRefRank = newMemRef.getType().cast().getRank(); + unsigned newMemRefRank = + llvm::cast(newMemRef.getType()).getRank(); (void)newMemRefRank; // unused in opt mode - unsigned oldMemRefRank = oldMemRef.getType().cast().getRank(); + unsigned oldMemRefRank = + llvm::cast(oldMemRef.getType()).getRank(); (void)oldMemRefRank; // unused in opt mode if (indexRemap) { assert(indexRemap.getNumSymbols() == symbolOperands.size() && @@ -1134,8 +1136,8 @@ } // Assert same elemental type. - assert(oldMemRef.getType().cast().getElementType() == - newMemRef.getType().cast().getElementType()); + assert(llvm::cast(oldMemRef.getType()).getElementType() == + llvm::cast(newMemRef.getType()).getElementType()); SmallVector usePositions; for (const auto &opEntry : llvm::enumerate(op->getOperands())) { @@ -1172,7 +1174,8 @@ // Perform index rewrites for the dereferencing op and then replace the op NamedAttribute oldMapAttrPair = affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef); - AffineMap oldMap = oldMapAttrPair.getValue().cast().getValue(); + AffineMap oldMap = + llvm::cast(oldMapAttrPair.getValue()).getValue(); unsigned oldMapNumInputs = oldMap.getNumInputs(); SmallVector oldMapOperands( op->operand_begin() + memRefOperandPos + 1, @@ -1294,9 +1297,11 @@ ArrayRef symbolOperands, Operation *domOpFilter, Operation *postDomOpFilter, bool allowNonDereferencingOps, bool replaceInDeallocOp) { - unsigned newMemRefRank = newMemRef.getType().cast().getRank(); + unsigned newMemRefRank = + llvm::cast(newMemRef.getType()).getRank(); (void)newMemRefRank; // unused in opt mode - unsigned oldMemRefRank = oldMemRef.getType().cast().getRank(); + unsigned oldMemRefRank = + llvm::cast(oldMemRef.getType()).getRank(); (void)oldMemRefRank; if (indexRemap) { assert(indexRemap.getNumSymbols() == symbolOperands.size() && @@ -1309,8 +1314,8 @@ } // Assert same elemental type. - assert(oldMemRef.getType().cast().getElementType() == - newMemRef.getType().cast().getElementType()); + assert(llvm::cast(oldMemRef.getType()).getElementType() == + llvm::cast(newMemRef.getType()).getElementType()); std::unique_ptr domInfo; std::unique_ptr postDomInfo; @@ -1734,7 +1739,7 @@ SmallVector> tileSizePos; (void)getTileSizePos(layoutMap, tileSizePos); if (newMemRefType.getNumDynamicDims() > 0 && !tileSizePos.empty()) { - MemRefType oldMemRefType = oldMemRef.getType().cast(); + MemRefType oldMemRefType = llvm::cast(oldMemRef.getType()); SmallVector newDynamicSizes; createNewDynamicSizes(oldMemRefType, newMemRefType, layoutMap, allocOp, b, newDynamicSizes); diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -36,15 +36,15 @@ static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs) { return builder.getIntegerAttr(res.getType(), - lhs.cast().getInt() + - rhs.cast().getInt()); + llvm::cast(lhs).getInt() + + llvm::cast(rhs).getInt()); } static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs) { return builder.getIntegerAttr(res.getType(), - lhs.cast().getInt() - - rhs.cast().getInt()); + llvm::cast(lhs).getInt() - + llvm::cast(rhs).getInt()); } /// Invert an integer comparison predicate. @@ -92,11 +92,11 @@ } static FailureOr getIntOrSplatIntValue(Attribute attr) { - if (auto intAttr = attr.dyn_cast()) + if (auto intAttr = llvm::dyn_cast(attr)) return intAttr.getValue(); - if (auto splatAttr = attr.dyn_cast()) - if (splatAttr.getElementType().isa()) + if (auto splatAttr = llvm::dyn_cast(attr)) + if (llvm::isa(splatAttr.getElementType())) return splatAttr.getSplatValue(); return failure(); @@ -117,11 +117,11 @@ /// Return the type of the same shape (scalar, vector or tensor) containing i1. static Type getI1SameShape(Type type) { auto i1Type = IntegerType::get(type.getContext(), 1); - if (auto tensorType = type.dyn_cast()) + if (auto tensorType = llvm::dyn_cast(type)) return RankedTensorType::get(tensorType.getShape(), i1Type); - if (type.isa()) + if (llvm::isa(type)) return UnrankedTensorType::get(i1Type); - if (auto vectorType = type.dyn_cast()) + if (auto vectorType = llvm::dyn_cast(type)) return VectorType::get(vectorType.getShape(), i1Type, vectorType.getNumScalableDims()); return i1Type; @@ -134,8 +134,8 @@ void arith::ConstantOp::getAsmResultNames( function_ref setNameFn) { auto type = getType(); - if (auto intCst = getValue().dyn_cast()) { - auto intType = type.dyn_cast(); + if (auto intCst = llvm::dyn_cast(getValue())) { + auto intType = llvm::dyn_cast(type); // Sugar i1 constants with 'true' and 'false'. if (intType && intType.getWidth() == 1) @@ -163,10 +163,11 @@ << " must match return type: " << type; } // Integer values must be signless. - if (type.isa() && !type.cast().isSignless()) + if (llvm::isa(type) && + !llvm::cast(type).isSignless()) return emitOpError("integer return type must be signless"); // Any float or elements attribute are acceptable. - if (!getValue().isa()) { + if (!llvm::isa(getValue())) { return emitOpError( "value must be an integer, float, or elements attribute"); } @@ -175,14 +176,15 @@ bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) { // The value's type must be the same as the provided type. - auto typedAttr = value.dyn_cast(); + auto typedAttr = llvm::dyn_cast(value); if (!typedAttr || typedAttr.getType() != type) return false; // Integer values must be signless. - if (type.isa() && !type.cast().isSignless()) + if (llvm::isa(type) && + !llvm::cast(type).isSignless()) return false; // Integer, float, and element attributes are buildable. - return value.isa(); + return llvm::isa(value); } ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value, @@ -223,7 +225,7 @@ bool arith::ConstantFloatOp::classof(Operation *op) { if (auto constOp = dyn_cast_or_null(op)) - return constOp.getType().isa(); + return llvm::isa(constOp.getType()); return false; } @@ -275,7 +277,7 @@ std::optional> arith::AddUIExtendedOp::getShapeForUnroll() { - if (auto vt = getType(0).dyn_cast()) + if (auto vt = llvm::dyn_cast(getType(0))) return llvm::to_vector<4>(vt.getShape()); return std::nullopt; } @@ -309,7 +311,7 @@ [](APInt a, const APInt &b) { return std::move(a) + b; })) { Attribute overflowAttr = constFoldBinaryOp( ArrayRef({sumAttr, adaptor.getLhs()}), - getI1SameShape(sumAttr.cast().getType()), + getI1SameShape(llvm::cast(sumAttr).getType()), calculateUnsignedOverflow); if (!overflowAttr) return failure(); @@ -385,7 +387,7 @@ std::optional> arith::MulSIExtendedOp::getShapeForUnroll() { - if (auto vt = getType(0).dyn_cast()) + if (auto vt = llvm::dyn_cast(getType(0))) return llvm::to_vector<4>(vt.getShape()); return std::nullopt; } @@ -433,7 +435,7 @@ std::optional> arith::MulUIExtendedOp::getShapeForUnroll() { - if (auto vt = getType(0).dyn_cast()) + if (auto vt = llvm::dyn_cast(getType(0))) return llvm::to_vector<4>(vt.getShape()); return std::nullopt; } @@ -1093,11 +1095,11 @@ template static Type getUnderlyingType(Type type, type_list, type_list) { - if (type.isa() && !type.isa()) + if (llvm::isa(type) && !llvm::isa(type)) return {}; auto underlyingType = getElementTypeOrSelf(type); - if (!underlyingType.isa()) + if (!llvm::isa(underlyingType)) return {}; return underlyingType; @@ -1133,7 +1135,8 @@ Type srcType = getElementTypeOrSelf(op.getIn().getType()); Type dstType = getElementTypeOrSelf(op.getType()); - if (srcType.cast().getWidth() >= dstType.cast().getWidth()) + if (llvm::cast(srcType).getWidth() >= + llvm::cast(dstType).getWidth()) return op.emitError("result type ") << dstType << " must be wider than operand type " << srcType; @@ -1146,7 +1149,8 @@ Type srcType = getElementTypeOrSelf(op.getIn().getType()); Type dstType = getElementTypeOrSelf(op.getType()); - if (srcType.cast().getWidth() <= dstType.cast().getWidth()) + if (llvm::cast(srcType).getWidth() <= + llvm::cast(dstType).getWidth()) return op.emitError("result type ") << dstType << " must be shorter than operand type " << srcType; @@ -1179,7 +1183,7 @@ } Type resType = getElementTypeOrSelf(getType()); - unsigned bitWidth = resType.cast().getWidth(); + unsigned bitWidth = llvm::cast(resType).getWidth(); return constFoldCastOp( adaptor.getOperands(), getType(), [bitWidth](const APInt &a, bool &castStatus) { @@ -1206,7 +1210,7 @@ } Type resType = getElementTypeOrSelf(getType()); - unsigned bitWidth = resType.cast().getWidth(); + unsigned bitWidth = llvm::cast(resType).getWidth(); return constFoldCastOp( adaptor.getOperands(), getType(), [bitWidth](const APInt &a, bool &castStatus) { @@ -1259,8 +1263,8 @@ Type dstType = getElementTypeOrSelf(getType()); // trunci(zexti(a)) -> trunci(a) // trunci(sexti(a)) -> trunci(a) - if (srcType.cast().getWidth() > - dstType.cast().getWidth()) { + if (llvm::cast(srcType).getWidth() > + llvm::cast(dstType).getWidth()) { setOperand(src); return getResult(); } @@ -1276,7 +1280,7 @@ } Type resType = getElementTypeOrSelf(getType()); - unsigned bitWidth = resType.cast().getWidth(); + unsigned bitWidth = llvm::cast(resType).getWidth(); return constFoldCastOp( adaptor.getOperands(), getType(), [bitWidth](const APInt &a, bool &castStatus) { @@ -1307,12 +1311,12 @@ /// can be represented without precision loss or rounding. OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) { auto constOperand = adaptor.getIn(); - if (!constOperand || !constOperand.isa()) + if (!constOperand || !llvm::isa(constOperand)) return {}; // Convert to target type via 'double'. double sourceValue = - constOperand.dyn_cast().getValue().convertToDouble(); + llvm::dyn_cast(constOperand).getValue().convertToDouble(); auto targetAttr = FloatAttr::get(getType(), sourceValue); // Propagate if constant's value does not change after truncation. @@ -1376,7 +1380,7 @@ return constFoldCastOp( adaptor.getOperands(), getType(), [&resEleType](const APInt &a, bool &castStatus) { - FloatType floatTy = resEleType.cast(); + FloatType floatTy = llvm::cast(resEleType); APFloat apf(floatTy.getFloatSemantics(), APInt::getZero(floatTy.getWidth())); apf.convertFromAPInt(a, /*IsSigned=*/false, @@ -1398,7 +1402,7 @@ return constFoldCastOp( adaptor.getOperands(), getType(), [&resEleType](const APInt &a, bool &castStatus) { - FloatType floatTy = resEleType.cast(); + FloatType floatTy = llvm::cast(resEleType); APFloat apf(floatTy.getFloatSemantics(), APInt::getZero(floatTy.getWidth())); apf.convertFromAPInt(a, /*IsSigned=*/true, @@ -1416,7 +1420,7 @@ OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) { Type resType = getElementTypeOrSelf(getType()); - unsigned bitWidth = resType.cast().getWidth(); + unsigned bitWidth = llvm::cast(resType).getWidth(); return constFoldCastOp( adaptor.getOperands(), getType(), [&bitWidth](const APFloat &a, bool &castStatus) { @@ -1438,7 +1442,7 @@ OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) { Type resType = getElementTypeOrSelf(getType()); - unsigned bitWidth = resType.cast().getWidth(); + unsigned bitWidth = llvm::cast(resType).getWidth(); return constFoldCastOp( adaptor.getOperands(), getType(), [&bitWidth](const APFloat &a, bool &castStatus) { @@ -1542,18 +1546,18 @@ return {}; /// Bitcast dense elements. - if (auto denseAttr = operand.dyn_cast_or_null()) - return denseAttr.bitcast(resType.cast().getElementType()); + if (auto denseAttr = llvm::dyn_cast_or_null(operand)) + return denseAttr.bitcast(llvm::cast(resType).getElementType()); /// Other shaped types unhandled. - if (resType.isa()) + if (llvm::isa(resType)) return {}; /// Bitcast integer or float to integer or float. - APInt bits = operand.isa() - ? operand.cast().getValue().bitcastToAPInt() - : operand.cast().getValue(); + APInt bits = llvm::isa(operand) + ? llvm::cast(operand).getValue().bitcastToAPInt() + : llvm::cast(operand).getValue(); - if (auto resFloatType = resType.dyn_cast()) + if (auto resFloatType = llvm::dyn_cast(resType)) return FloatAttr::get(resType, APFloat(resFloatType.getFloatSemantics(), bits)); return IntegerAttr::get(resType, bits); @@ -1618,18 +1622,18 @@ static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) { auto boolAttr = BoolAttr::get(ctx, value); - ShapedType shapedType = type.dyn_cast_or_null(); + ShapedType shapedType = llvm::dyn_cast_or_null(type); if (!shapedType) return boolAttr; return DenseElementsAttr::get(shapedType, boolAttr); } static std::optional getIntegerWidth(Type t) { - if (auto intType = t.dyn_cast()) { + if (auto intType = llvm::dyn_cast(t)) { return intType.getWidth(); } - if (auto vectorIntType = t.dyn_cast()) { - return vectorIntType.getElementType().cast().getWidth(); + if (auto vectorIntType = llvm::dyn_cast(t)) { + return llvm::cast(vectorIntType.getElementType()).getWidth(); } return std::nullopt; } @@ -1817,7 +1821,7 @@ // Get the width of the mantissa. We don't want to hack on conversions that // might lose information from the integer, e.g. "i64 -> float" - FloatType floatTy = op.getRhs().getType().cast(); + FloatType floatTy = llvm::cast(op.getRhs().getType()); int mantissaWidth = floatTy.getFPMantissaWidth(); if (mantissaWidth <= 0) return failure(); @@ -1837,7 +1841,7 @@ // Check to see that the input is converted from an integer type that is // small enough that preserves all bits. - auto intTy = intVal.getType().cast(); + auto intTy = llvm::cast(intVal.getType()); auto intWidth = intTy.getWidth(); // Number of bits representing values, as opposed to the sign @@ -2103,7 +2107,7 @@ LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override { // Cannot extui i1 to i1, or i1 to f32 - if (!op.getType().isa() || op.getType().isInteger(1)) + if (!llvm::isa(op.getType()) || op.getType().isInteger(1)) return failure(); // select %x, c1, %c0 => extui %arg @@ -2230,7 +2234,8 @@ p << " " << getOperands(); p.printOptionalAttrDict((*this)->getAttrs()); p << " : "; - if (ShapedType condType = getCondition().getType().dyn_cast()) + if (ShapedType condType = + llvm::dyn_cast(getCondition().getType())) p << condType << ", "; p << getType(); } @@ -2243,7 +2248,7 @@ // If the result type is a vector or tensor, the type can be a mask with the // same elements. Type resultType = getType(); - if (!resultType.isa()) + if (!llvm::isa(resultType)) return emitOpError() << "expected condition to be a signless i1, but got " << conditionType; Type shapedConditionType = getI1SameShape(resultType); @@ -2320,7 +2325,7 @@ case AtomicRMWKind::maxf: return builder.getFloatAttr( resultType, - APFloat::getInf(resultType.cast().getFloatSemantics(), + APFloat::getInf(llvm::cast(resultType).getFloatSemantics(), /*Negative=*/true)); case AtomicRMWKind::addf: case AtomicRMWKind::addi: @@ -2330,24 +2335,24 @@ case AtomicRMWKind::andi: return builder.getIntegerAttr( resultType, - APInt::getAllOnes(resultType.cast().getWidth())); + APInt::getAllOnes(llvm::cast(resultType).getWidth())); case AtomicRMWKind::maxs: return builder.getIntegerAttr( - resultType, - APInt::getSignedMinValue(resultType.cast().getWidth())); + resultType, APInt::getSignedMinValue( + llvm::cast(resultType).getWidth())); case AtomicRMWKind::minf: return builder.getFloatAttr( resultType, - APFloat::getInf(resultType.cast().getFloatSemantics(), + APFloat::getInf(llvm::cast(resultType).getFloatSemantics(), /*Negative=*/false)); case AtomicRMWKind::mins: return builder.getIntegerAttr( - resultType, - APInt::getSignedMaxValue(resultType.cast().getWidth())); + resultType, APInt::getSignedMaxValue( + llvm::cast(resultType).getWidth())); case AtomicRMWKind::minu: return builder.getIntegerAttr( resultType, - APInt::getMaxValue(resultType.cast().getWidth())); + APInt::getMaxValue(llvm::cast(resultType).getWidth())); case AtomicRMWKind::muli: return builder.getIntegerAttr(resultType, 1); case AtomicRMWKind::mulf: diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp --- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp +++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp @@ -25,7 +25,7 @@ void arith::ConstantOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - auto constAttr = getValue().dyn_cast_or_null(); + auto constAttr = llvm::dyn_cast_or_null(getValue()); if (constAttr) { const APInt &value = constAttr.getValue(); setResultRange(getResult(), ConstantIntRanges::constant(value)); diff --git a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp @@ -37,7 +37,7 @@ auto constantOp = cast(op); assert(value == constantOp.getResult() && "invalid value"); - if (auto attr = constantOp.getValue().dyn_cast()) + if (auto attr = llvm::dyn_cast(constantOp.getValue())) cstr.bound(value) == attr.getInt(); } }; diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp @@ -34,7 +34,7 @@ return constantOp->emitError("could not infer memory space"); // Only ranked tensors are supported. - if (!constantOp.getType().isa()) + if (!llvm::isa(constantOp.getType())) return failure(); // Only constants inside a module are supported. @@ -58,7 +58,7 @@ bool isWritable(Operation *op, Value value, const AnalysisState &state) const { // Memory locations returned by memref::GetGlobalOp may not be written to. - assert(value.isa()); + assert(llvm::isa(value)); return false; } }; @@ -84,21 +84,21 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto castOp = cast(op); - auto resultTensorType = castOp.getType().cast(); + auto resultTensorType = llvm::cast(castOp.getType()); FailureOr source = getBuffer(rewriter, castOp.getIn(), options); if (failed(source)) return failure(); - auto sourceType = source->getType().cast(); + auto sourceType = llvm::cast(source->getType()); // Result type should have same layout and address space as the source type. BaseMemRefType resultType; - if (auto rankedMemRefType = sourceType.dyn_cast()) { + if (auto rankedMemRefType = llvm::dyn_cast(sourceType)) { resultType = MemRefType::get( rankedMemRefType.getShape(), resultTensorType.getElementType(), rankedMemRefType.getLayout(), rankedMemRefType.getMemorySpace()); } else { - auto unrankedMemrefType = sourceType.cast(); + auto unrankedMemrefType = llvm::cast(sourceType); resultType = UnrankedMemRefType::get(resultTensorType.getElementType(), unrankedMemrefType.getMemorySpace()); } diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp --- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp @@ -86,12 +86,12 @@ continue; } - assert(value.getType().cast().isDynamicDim(*dim) && + assert(llvm::cast(value.getType()).isDynamicDim(*dim) && "expected dynamic dim"); - if (value.getType().isa()) { + if (llvm::isa(value.getType())) { // A tensor dimension is used: generate a tensor.dim. operands.push_back(b.create(loc, value, *dim)); - } else if (value.getType().isa()) { + } else if (llvm::isa(value.getType())) { // A memref dimension is used: generate a memref.dim. operands.push_back(b.create(loc, value, *dim)); } else { diff --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp --- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp +++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp @@ -28,7 +28,7 @@ /// Return the scalable vector of the same shape and containing i1. static Type getI1SameShape(Type type) { auto i1Type = IntegerType::get(type.getContext(), 1); - if (auto sVectorType = type.dyn_cast()) + if (auto sVectorType = llvm::dyn_cast(type)) return VectorType::get(sVectorType.getShape(), i1Type, sVectorType.getNumScalableDims()); return nullptr; diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -42,7 +42,7 @@ auto executeOp = (*this)->getParentOfType(); auto types = llvm::map_range(executeOp.getBodyResults(), [](const OpResult &result) { - return result.getType().cast().getValueType(); + return llvm::cast(result.getType()).getValueType(); }); if (getOperandTypes() != types) @@ -71,7 +71,7 @@ bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) { const auto getValueOrTokenType = [](Type type) { - if (auto value = type.dyn_cast()) + if (auto value = llvm::dyn_cast(type)) return value.getValueType(); return type; }; @@ -118,7 +118,7 @@ bodyRegion->push_back(new Block); Block &bodyBlock = bodyRegion->front(); for (Value operand : operands) { - auto valueType = operand.getType().dyn_cast(); + auto valueType = llvm::dyn_cast(operand.getType()); bodyBlock.addArgument(valueType ? valueType.getValueType() : operand.getType(), operand.getLoc()); @@ -195,7 +195,7 @@ parser.parseColonType(valueTypes.emplace_back())) return failure(); - auto valueTy = valueTypes.back().dyn_cast(); + auto valueTy = llvm::dyn_cast(valueTypes.back()); unwrappedArgs.back().type = valueTy ? valueTy.getValueType() : Type(); return success(); }; @@ -234,7 +234,7 @@ LogicalResult ExecuteOp::verifyRegions() { // Unwrap async.execute value operands types. auto unwrappedTypes = llvm::map_range(getBodyOperands(), [](Value operand) { - return operand.getType().cast().getValueType(); + return llvm::cast(operand.getType()).getValueType(); }); // Verify that unwrapped argument types matches the body region arguments. @@ -285,7 +285,7 @@ result.attributes.append(attrs.begin(), attrs.end()); // Add unwrapped async.value type to the returned values types. - if (auto valueType = operand.getType().dyn_cast()) + if (auto valueType = llvm::dyn_cast(operand.getType())) result.addTypes(valueType.getValueType()); } @@ -295,7 +295,7 @@ return failure(); // Add unwrapped async.value type to the returned values types. - if (auto valueType = operandType.dyn_cast()) + if (auto valueType = llvm::dyn_cast(operandType)) resultType = valueType.getValueType(); return success(); @@ -310,11 +310,11 @@ Type argType = getOperand().getType(); // Awaiting on a token does not have any results. - if (argType.isa() && !getResultTypes().empty()) + if (llvm::isa(argType) && !getResultTypes().empty()) return emitOpError("awaiting on a token must have empty result"); // Awaiting on a value unwraps the async value type. - if (auto value = argType.dyn_cast()) { + if (auto value = llvm::dyn_cast(argType)) { if (*getResultType() != value.getValueType()) return emitOpError() << "result type " << *getResultType() << " does not match async value type " @@ -375,12 +375,12 @@ for (unsigned i = 0, e = resultTypes.size(); i != e; ++i) { auto type = resultTypes[i]; - if (!type.isa() && !type.isa()) + if (!llvm::isa(type) && !llvm::isa(type)) return emitOpError() << "result type must be async value type or async " "token type, but got " << type; // We only allow AsyncToken appear as the first return value - if (type.isa() && i != 0) { + if (llvm::isa(type) && i != 0) { return emitOpError() << " results' (optional) async token type is expected " "to appear as the 1st return value, but got " @@ -446,7 +446,7 @@ // Get the underlying value types from async types returned from the // parent `async.func` operation. auto types = llvm::map_range(resultTypes, [](const Type &result) { - return result.cast().getValueType(); + return llvm::cast(result).getValueType(); }); if (getOperandTypes() != types) diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -91,9 +91,9 @@ } Operation *bufferization::getOwnerOfValue(Value value) { - if (auto opResult = value.dyn_cast()) + if (auto opResult = llvm::dyn_cast(value)) return opResult.getDefiningOp(); - return value.cast().getOwner()->getParentOp(); + return llvm::cast(value).getOwner()->getParentOp(); } bool bufferization::allocationDoesNotEscape(OpResult opResult) { @@ -109,7 +109,7 @@ return false; auto attr = op->getAttrOfType(BufferizationDialect::kEscapeAttrName); - return !attr[opResult.getResultNumber()].cast().getValue(); + return !llvm::cast(attr[opResult.getResultNumber()]).getValue(); } /// Create an AllocTensorOp for the given shaped value. If `copy` is set, the @@ -119,31 +119,31 @@ OpBuilder &b, Location loc, Value shapedValue, bool escape, const BufferizationOptions &options, bool copy) { Value tensor; - if (shapedValue.getType().isa()) { + if (llvm::isa(shapedValue.getType())) { tensor = shapedValue; - } else if (shapedValue.getType().isa()) { + } else if (llvm::isa(shapedValue.getType())) { tensor = b.create(loc, shapedValue); - } else if (shapedValue.getType().isa() || - shapedValue.getType().isa()) { + } else if (llvm::isa(shapedValue.getType()) || + llvm::isa(shapedValue.getType())) { return getOwnerOfValue(shapedValue) ->emitError("copying of unranked tensors is not implemented"); } else { llvm_unreachable("expected RankedTensorType or MemRefType"); } - RankedTensorType tensorType = tensor.getType().cast(); + RankedTensorType tensorType = llvm::cast(tensor.getType()); SmallVector dynamicSizes; if (!copy) { // Compute the dynamic part of the shape. // First try to query the shape via ReifyRankedShapedTypeOpInterface. bool reifiedShapes = false; - if (shapedValue.getType().isa() && - shapedValue.isa()) { + if (llvm::isa(shapedValue.getType()) && + llvm::isa(shapedValue)) { ReifiedRankedShapedTypeDims resultDims; if (succeeded( reifyResultShapes(b, shapedValue.getDefiningOp(), resultDims))) { reifiedShapes = true; auto &shape = - resultDims[shapedValue.cast().getResultNumber()]; + resultDims[llvm::cast(shapedValue).getResultNumber()]; for (const auto &dim : enumerate(tensorType.getShape())) if (ShapedType::isDynamic(dim.value())) dynamicSizes.push_back(shape[dim.index()].get()); @@ -188,11 +188,11 @@ // Find all out-of-place OpOperands. for (OpOperand &opOperand : op->getOpOperands()) { Type operandType = opOperand.get().getType(); - if (!operandType.isa()) + if (!llvm::isa(operandType)) continue; if (state.isInPlace(opOperand)) continue; - if (operandType.isa()) + if (llvm::isa(operandType)) return op->emitError("copying of unranked tensors is not implemented"); AliasingOpResultList aliasingOpResults = @@ -209,9 +209,8 @@ !state.bufferizesToMemoryWrite(opOperand) && state.getAliasingOpOperands(aliasingOpResults.getAliases()[0].opResult) .getNumAliases() == 1 && - !aliasingOpResults.getAliases()[0] - .opResult.getType() - .isa()) { + !llvm::isa( + aliasingOpResults.getAliases()[0].opResult.getType())) { // The op itself does not write but may create exactly one alias. Instead // of copying the OpOperand, copy the OpResult. The OpResult can sometimes // be smaller than the OpOperand (e.g., in the case of an extract_slice, @@ -281,9 +280,9 @@ AnalysisState analysisState(options); if (op->hasAttr(BufferizationDialect::kEscapeAttrName)) { // AllocTensorOp has one result. - ArrayAttr escapeAttr = - op->getAttr(BufferizationDialect::kEscapeAttrName).cast(); - return !escapeAttr[0].cast().getValue(); + ArrayAttr escapeAttr = llvm::cast( + op->getAttr(BufferizationDialect::kEscapeAttrName)); + return !llvm::cast(escapeAttr[0]).getValue(); } // No "escape" annotation found. @@ -335,8 +334,8 @@ BaseMemRefType defaultUnknownTypeConverter(Value value, Attribute memorySpace, const BufferizationOptions &options) { - return getMemRefTypeWithFullyDynamicLayout(value.getType().cast(), - memorySpace); + return getMemRefTypeWithFullyDynamicLayout( + llvm::cast(value.getType()), memorySpace); } } // namespace @@ -394,7 +393,7 @@ //===----------------------------------------------------------------------===// static void setInsertionPointAfter(OpBuilder &b, Value value) { - if (auto bbArg = value.dyn_cast()) { + if (auto bbArg = llvm::dyn_cast(value)) { b.setInsertionPointToStart(bbArg.getOwner()); } else { b.setInsertionPointAfter(value.getDefiningOp()); @@ -463,7 +462,7 @@ } bool AnalysisState::bufferizesToMemoryWrite(Value value) const { - auto opResult = value.dyn_cast(); + auto opResult = llvm::dyn_cast(value); if (!opResult) return true; auto bufferizableOp = getOptions().dynCastBufferizableOp(value); @@ -476,7 +475,7 @@ /// read. Also takes into account ops that create an alias but do not read by /// themselves (e.g., ExtractSliceOp). bool AnalysisState::isValueRead(Value value) const { - assert(value.getType().isa() && "expected TensorType"); + assert(llvm::isa(value.getType()) && "expected TensorType"); SmallVector workingSet; for (OpOperand &use : value.getUses()) workingSet.push_back(&use); @@ -512,13 +511,13 @@ continue; } - if (value.isa()) { + if (llvm::isa(value)) { if (alwaysIncludeLeaves) result.insert(value); continue; } - OpResult opResult = value.cast(); + OpResult opResult = llvm::cast(value); BufferizableOpInterface bufferizableOp = options.dynCastBufferizableOp(opResult.getDefiningOp()); AliasingOpOperandList aliases = getAliasingOpOperands(opResult); @@ -658,8 +657,8 @@ // bufferization.to_memref is not allowed to change the rank. static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) { #ifndef NDEBUG - auto rankedTensorType = tensor.getType().dyn_cast(); - assert((!rankedTensorType || memrefType.cast().getRank() == + auto rankedTensorType = llvm::dyn_cast(tensor.getType()); + assert((!rankedTensorType || llvm::cast(memrefType).getRank() == rankedTensorType.getRank()) && "to_memref would be invalid: mismatching ranks"); #endif @@ -668,7 +667,7 @@ FailureOr bufferization::getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options) { #ifndef NDEBUG - auto tensorType = value.getType().dyn_cast(); + auto tensorType = llvm::dyn_cast(value.getType()); assert(tensorType && "unexpected non-tensor type"); #endif // NDEBUG @@ -699,7 +698,8 @@ FailureOr bufferization::getBufferType( Value value, const BufferizationOptions &options, const DenseMap &fixedTypes) { - assert(value.getType().isa() && "unexpected non-tensor type"); + assert(llvm::isa(value.getType()) && + "unexpected non-tensor type"); // If the `value` is in `fixedTypes`, return the mapped type. const auto &it = fixedTypes.find(value); @@ -731,11 +731,11 @@ SmallVector replacements; for (OpResult opResult : op->getOpResults()) { Value replacement = values[opResult.getResultNumber()]; - if (opResult.getType().isa()) { + if (llvm::isa(opResult.getType())) { // The OpResult is a tensor. Such values are replaced with memrefs during // bufferization. - assert((replacement.getType().isa() || - replacement.getType().isa()) && + assert((llvm::isa(replacement.getType()) || + llvm::isa(replacement.getType())) && "tensor op result should be replaced with a memref value"); // The existing uses of the OpResult still expect a tensor. Insert a // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually @@ -797,7 +797,7 @@ //===----------------------------------------------------------------------===// bool bufferization::isFunctionArgument(Value value) { - auto bbArg = value.dyn_cast(); + auto bbArg = llvm::dyn_cast(value); if (!bbArg) return false; return isa(bbArg.getOwner()->getParentOp()); @@ -807,17 +807,18 @@ const BufferizationOptions &options, MemRefLayoutAttrInterface layout, Attribute memorySpace) { - auto tensorType = value.getType().cast(); + auto tensorType = llvm::cast(value.getType()); // Case 1: Unranked memref type. - if (auto unrankedTensorType = tensorType.dyn_cast()) { + if (auto unrankedTensorType = + llvm::dyn_cast(tensorType)) { assert(!layout && "UnrankedTensorType cannot have a layout map"); return UnrankedMemRefType::get(unrankedTensorType.getElementType(), memorySpace); } // Case 2: Ranked memref type with specified layout. - auto rankedTensorType = tensorType.cast(); + auto rankedTensorType = llvm::cast(tensorType); if (layout) { return MemRefType::get(rankedTensorType.getShape(), rankedTensorType.getElementType(), layout, @@ -831,13 +832,14 @@ bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, Attribute memorySpace) { // Case 1: Unranked memref type. - if (auto unrankedTensorType = tensorType.dyn_cast()) { + if (auto unrankedTensorType = + llvm::dyn_cast(tensorType)) { return UnrankedMemRefType::get(unrankedTensorType.getElementType(), memorySpace); } // Case 2: Ranked memref type. - auto rankedTensorType = tensorType.cast(); + auto rankedTensorType = llvm::cast(tensorType); int64_t dynamicOffset = ShapedType::kDynamic; SmallVector dynamicStrides(rankedTensorType.getRank(), ShapedType::kDynamic); @@ -854,13 +856,14 @@ bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType, Attribute memorySpace) { // Case 1: Unranked memref type. - if (auto unrankedTensorType = tensorType.dyn_cast()) { + if (auto unrankedTensorType = + llvm::dyn_cast(tensorType)) { return UnrankedMemRefType::get(unrankedTensorType.getElementType(), memorySpace); } // Case 2: Ranked memref type. - auto rankedTensorType = tensorType.cast(); + auto rankedTensorType = llvm::cast(tensorType); MemRefLayoutAttrInterface layout = {}; return MemRefType::get(rankedTensorType.getShape(), rankedTensorType.getElementType(), layout, @@ -943,7 +946,7 @@ Operation *op = opResult.getDefiningOp(); SmallVector result; for (OpOperand &opOperand : op->getOpOperands()) { - if (!opOperand.get().getType().isa()) + if (!llvm::isa(opOperand.get().getType())) continue; AliasingOpResultList aliasingOpResults = state.getAliasingOpResults(opOperand); @@ -957,15 +960,15 @@ FailureOr bufferization::detail::defaultGetBufferType( Value value, const BufferizationOptions &options, const DenseMap &fixedTypes) { - assert(value.getType().isa() && "expected tensor type"); + assert(llvm::isa(value.getType()) && "expected tensor type"); // No further analysis is possible for a block argument. - if (value.isa()) + if (llvm::isa(value)) return bufferization::getMemRefType(value, options); // Value is an OpResult. Operation *op = getOwnerOfValue(value); - auto opResult = value.cast(); + auto opResult = llvm::cast(value); AnalysisState state(options); AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult); if (aliases.getNumAliases() > 0 && @@ -1000,7 +1003,7 @@ // Conservatively assume that everything may be aliasing. AliasingOpOperandList r; for (OpOperand &operand : opResult.getDefiningOp()->getOpOperands()) - if (operand.get().getType().isa()) + if (llvm::isa(operand.get().getType())) r.addAlias({&operand, BufferRelation::Unknown, /*isDefinite=*/false}); return r; } @@ -1010,7 +1013,7 @@ // Conservatively assume that everything may be aliasing. AliasingOpResultList r; for (OpResult result : opOperand.getOwner()->getOpResults()) - if (result.getType().isa()) + if (llvm::isa(result.getType())) r.addAlias({result, BufferRelation::Unknown, /*isDefinite=*/false}); return r; } diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp @@ -62,7 +62,7 @@ Operation *op, unsigned /*regionIndex*/, unsigned argIndex, NamedAttribute attr) { if (attr.getName() == kWritableAttrName) { - if (!attr.getValue().isa()) { + if (!llvm::isa(attr.getValue())) { return op->emitError() << "'" << kWritableAttrName << "' is expected to be a boolean attribute"; } @@ -75,11 +75,11 @@ return success(); } if (attr.getName() == kBufferAccessAttrName) { - if (!attr.getValue().isa()) { + if (!llvm::isa(attr.getValue())) { return op->emitError() << "'" << kBufferAccessAttrName << "' is expected to be a string attribute"; } - StringRef str = attr.getValue().cast().getValue(); + StringRef str = llvm::cast(attr.getValue()).getValue(); if (str != "none" && str != "read" && str != "write" && str != "read-write") return op->emitError() << "invalid value for '" << kBufferAccessAttrName << "'"; @@ -89,7 +89,7 @@ return success(); } if (attr.getName() == kBufferLayoutAttrName) { - if (!attr.getValue().isa()) { + if (!llvm::isa(attr.getValue())) { return op->emitError() << "'" << kBufferLayoutAttrName << "' is expected to be a affine map attribute"; } @@ -109,7 +109,7 @@ using bufferization::BufferizableOpInterface; if (attr.getName() == kEscapeAttrName) { - auto arrayAttr = attr.getValue().dyn_cast(); + auto arrayAttr = llvm::dyn_cast(attr.getValue()); if (!arrayAttr) return op->emitError() << "'" << kEscapeAttrName << "' is expected to be a bool array attribute"; @@ -124,13 +124,13 @@ << "'" << kEscapeAttrName << "' only valid on bufferizable ops"; for (const auto &it : llvm::enumerate(arrayAttr)) { auto attr = it.value(); - auto boolAttr = attr.dyn_cast(); + auto boolAttr = llvm::dyn_cast(attr); if (!boolAttr) return op->emitError() << "'" << kEscapeAttrName << "' is expected to be a bool array attribute"; if (!boolAttr.getValue()) continue; - if (!op->getResult(it.index()).getType().isa()) + if (!llvm::isa(op->getResult(it.index()).getType())) return op->emitError() << "'" << kEscapeAttrName << "' only valid for tensor results"; if (!bufferizableOp.bufferizesToAllocation(op->getOpResult(it.index()))) diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -27,7 +27,7 @@ FailureOr mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value, MemRefType destType) { - auto srcType = value.getType().cast(); + auto srcType = llvm::cast(value.getType()); // Element type, rank and memory space must match. if (srcType.getElementType() != destType.getElementType()) @@ -100,9 +100,9 @@ return success(); } - auto rankedSrcType = srcType.dyn_cast(); - auto rankedDestType = destType.dyn_cast(); - auto unrankedSrcType = srcType.dyn_cast(); + auto rankedSrcType = llvm::dyn_cast(srcType); + auto rankedDestType = llvm::dyn_cast(destType); + auto unrankedSrcType = llvm::dyn_cast(srcType); // Ranked memref -> Ranked memref cast. if (rankedSrcType && rankedDestType) { @@ -132,13 +132,13 @@ void mlir::bufferization::populateDynamicDimSizes( OpBuilder &b, Location loc, Value shapedValue, SmallVector &dynamicDims) { - auto shapedType = shapedValue.getType().cast(); + auto shapedType = llvm::cast(shapedValue.getType()); for (int64_t i = 0; i < shapedType.getRank(); ++i) { if (shapedType.isDynamicDim(i)) { - if (shapedType.isa()) { + if (llvm::isa(shapedType)) { dynamicDims.push_back(b.create(loc, shapedValue, i)); } else { - assert(shapedType.isa() && "expected tensor"); + assert(llvm::isa(shapedType) && "expected tensor"); dynamicDims.push_back(b.create(loc, shapedValue, i)); } } @@ -191,7 +191,7 @@ // Should the buffer be deallocated? bool dealloc = - shouldDeallocateOpResult(getResult().cast(), options); + shouldDeallocateOpResult(llvm::cast(getResult()), options); // Replace op. replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc); @@ -431,7 +431,7 @@ AllocTensorOp::getOperandSegmentSizeAttr()}); p << " : "; auto type = getResult().getType(); - if (auto validType = type.dyn_cast<::mlir::TensorType>()) + if (auto validType = llvm::dyn_cast<::mlir::TensorType>(type)) p.printStrippedAttrOrType(validType); else p << type; @@ -620,8 +620,8 @@ toMemref.getOperand().getDefiningOp(); if (!tensorCastOperand) return failure(); - auto srcTensorType = - tensorCastOperand.getOperand().getType().dyn_cast(); + auto srcTensorType = llvm::dyn_cast( + tensorCastOperand.getOperand().getType()); if (!srcTensorType) return failure(); auto memrefType = MemRefType::get(srcTensorType.getShape(), diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp --- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp +++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp @@ -59,7 +59,7 @@ // This transform op is currently restricted to ModuleOps and function ops. // Such ops are modified in-place. - transformResults.set(getTransformed().cast(), payloadOps); + transformResults.set(llvm::cast(getTransformed()), payloadOps); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp @@ -280,7 +280,7 @@ // defined in a non-dominated block or it is defined in the same block // but the current value is not dominated by the source value. if (!dominators.dominates(definingBlock, parentBlock) || - (definingBlock == parentBlock && value.isa())) { + (definingBlock == parentBlock && llvm::isa(value))) { toProcess.emplace_back(value, parentBlock); valuesToFree.insert(value); } else if (visitedValues.insert(std::make_tuple(value, definingBlock)) @@ -307,8 +307,8 @@ // Add new allocs and additional clone operations. for (Value value : valuesToFree) { - if (failed(value.isa() - ? introduceBlockArgCopy(value.cast()) + if (failed(llvm::isa(value) + ? introduceBlockArgCopy(llvm::cast(value)) : introduceValueCopyForRegionResult(value))) return failure(); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp @@ -43,7 +43,7 @@ /// exceed the stack space. static bool defaultIsSmallAlloc(Value alloc, unsigned maximumSizeInBytes, unsigned maxRankOfAllocatedMemRef) { - auto type = alloc.getType().dyn_cast(); + auto type = llvm::dyn_cast(alloc.getType()); if (!type || !alloc.getDefiningOp()) return false; if (!type.hasStaticShape()) { @@ -355,7 +355,7 @@ OpBuilder builder(startOperation); Operation *allocOp = alloc.getDefiningOp(); Operation *alloca = builder.create( - alloc.getLoc(), alloc.getType().cast(), + alloc.getLoc(), llvm::cast(alloc.getType()), allocOp->getOperands(), allocOp->getAttrs()); // Replace the original alloc by a newly created alloca. diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp @@ -53,7 +53,7 @@ SmallVector erasedResultTypes; BitVector erasedResultIndices(functionType.getNumResults()); for (const auto &resultType : llvm::enumerate(functionType.getResults())) { - if (auto memrefType = resultType.value().dyn_cast()) { + if (auto memrefType = llvm::dyn_cast(resultType.value())) { if (!hasStaticIdentityLayout(memrefType) && !hasFullyDynamicLayoutMap(memrefType)) { // Only buffers with static identity layout can be allocated. These can @@ -103,7 +103,7 @@ SmallVector copyIntoOutParams; SmallVector keepAsReturnOperands; for (Value operand : op.getOperands()) { - if (operand.getType().isa()) + if (llvm::isa(operand.getType())) copyIntoOutParams.push_back(operand); else keepAsReturnOperands.push_back(operand); @@ -137,7 +137,7 @@ SmallVector replaceWithNewCallResults; SmallVector replaceWithOutParams; for (OpResult result : op.getResults()) { - if (result.getType().isa()) + if (llvm::isa(result.getType())) replaceWithOutParams.push_back(result); else replaceWithNewCallResults.push_back(result); @@ -145,13 +145,13 @@ SmallVector outParams; OpBuilder builder(op); for (Value memref : replaceWithOutParams) { - if (!memref.getType().cast().hasStaticShape()) { + if (!llvm::cast(memref.getType()).hasStaticShape()) { op.emitError() << "cannot create out param for dynamically shaped result"; didFail = true; return; } - auto memrefType = memref.getType().cast(); + auto memrefType = llvm::cast(memref.getType()); auto allocType = MemRefType::get(memrefType.getShape(), memrefType.getElementType(), AffineMap(), memrefType.getMemorySpace()); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp @@ -68,7 +68,7 @@ [=](MemoryEffects::EffectInstance &it) { Value value = it.getValue(); return isa(it.getEffect()) && value && - value.isa() && + llvm::isa(value) && it.getResource() != SideEffects::AutomaticAllocationScopeResource::get(); }); @@ -149,7 +149,7 @@ FailureOr bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment, Attribute memorySpace) { - auto type = constantOp.getType().cast(); + auto type = llvm::cast(constantOp.getType()); auto moduleOp = constantOp->getParentOfType(); if (!moduleOp) return failure(); @@ -185,14 +185,14 @@ : IntegerAttr(); BufferizeTypeConverter typeConverter; - auto memrefType = typeConverter.convertType(type).cast(); + auto memrefType = llvm::cast(typeConverter.convertType(type)); if (memorySpace) memrefType = MemRefType::Builder(memrefType).setMemorySpace(memorySpace); auto global = globalBuilder.create( constantOp.getLoc(), (Twine("__constant_") + os.str()).str(), /*sym_visibility=*/globalBuilder.getStringAttr("private"), /*type=*/memrefType, - /*initial_value=*/constantOp.getValue().cast(), + /*initial_value=*/llvm::cast(constantOp.getValue()), /*constant=*/true, /*alignment=*/memrefAlignment); symbolTable.insert(global); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -44,7 +44,7 @@ static Value materializeToTensor(OpBuilder &builder, TensorType type, ValueRange inputs, Location loc) { assert(inputs.size() == 1); - assert(inputs[0].getType().isa()); + assert(llvm::isa(inputs[0].getType())); return builder.create(loc, type, inputs[0]); } @@ -66,11 +66,11 @@ ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1 && "expected exactly one input"); - if (auto inputType = inputs[0].getType().dyn_cast()) { + if (auto inputType = llvm::dyn_cast(inputs[0].getType())) { // MemRef to MemRef cast. assert(inputType != type && "expected different types"); // Unranked to ranked and ranked to unranked casts must be explicit. - auto rankedDestType = type.dyn_cast(); + auto rankedDestType = llvm::dyn_cast(type); if (!rankedDestType) return nullptr; FailureOr replacement = @@ -80,7 +80,7 @@ return *replacement; } - if (inputs[0].getType().isa()) { + if (llvm::isa(inputs[0].getType())) { // Tensor to MemRef cast. return builder.create(loc, type, inputs[0]); } @@ -222,7 +222,7 @@ parseLayoutMapOption(unknownTypeConversion); opt.unknownTypeConverterFn = [=](Value value, Attribute memorySpace, const BufferizationOptions &options) { - auto tensorType = value.getType().cast(); + auto tensorType = llvm::cast(value.getType()); if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap) return bufferization::getMemRefTypeWithStaticIdentityLayout( tensorType, memorySpace); @@ -325,7 +325,7 @@ // BufferizableOpInterface-based Bufferization //===----------------------------------------------------------------------===// -static bool isaTensor(Type t) { return t.isa(); } +static bool isaTensor(Type t) { return llvm::isa(t); } /// Return true if the given op has a tensor result or a tensor operand. static bool hasTensorSemantics(Operation *op) { @@ -549,7 +549,7 @@ options.unknownTypeConverterFn = [](Value value, Attribute memorySpace, const BufferizationOptions &options) { return getMemRefTypeWithStaticIdentityLayout( - value.getType().cast(), memorySpace); + llvm::cast(value.getType()), memorySpace); }; options.opFilter.allowDialect(); return options; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -60,7 +60,7 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index, const BufferizationOptions &options) { auto tensorType = - funcOp.getFunctionType().getInput(index).dyn_cast(); + llvm::dyn_cast(funcOp.getFunctionType().getInput(index)); assert(tensorType && "expected TensorType"); BaseMemRefType memrefType = options.functionArgTypeConverterFn( @@ -71,7 +71,7 @@ if (!layoutAttr) return memrefType; - auto rankedMemrefType = memrefType.dyn_cast(); + auto rankedMemrefType = llvm::dyn_cast(memrefType); assert(rankedMemrefType && "buffer layout not supported on unranked tensors"); return MemRefType::get( rankedMemrefType.getShape(), rankedMemrefType.getElementType(), @@ -224,7 +224,7 @@ for (const auto &it : llvm::enumerate(callOp.getResultTypes())) { unsigned returnValIdx = it.index(); Type returnType = it.value(); - if (!returnType.isa()) { + if (!llvm::isa(returnType)) { // Non-tensor values are returned. retValMapping[returnValIdx] = resultTypes.size(); resultTypes.push_back(returnType); @@ -242,7 +242,7 @@ Value tensorOperand = opOperand.get(); // Non-tensor operands are just copied. - if (!tensorOperand.getType().isa()) { + if (!llvm::isa(tensorOperand.getType())) { newOperands[idx] = tensorOperand; continue; } @@ -342,7 +342,7 @@ SmallVector argTypes; for (const auto &it : llvm::enumerate(funcType.getInputs())) { Type argType = it.value(); - if (auto tensorType = argType.dyn_cast()) { + if (auto tensorType = llvm::dyn_cast(argType)) { argTypes.push_back( getBufferizedFunctionArgType(funcOp, it.index(), options)); continue; @@ -356,7 +356,7 @@ if (funcOp.getBody().empty()) { SmallVector retTypes; for (Type resultType : funcType.getResults()) { - if (resultType.isa()) + if (llvm::isa(resultType)) return funcOp->emitError() << "cannot bufferize bodiless function " << "that returns a tensor"; retTypes.push_back(resultType); @@ -373,7 +373,7 @@ // 1. Rewrite the bbArgs. Turn every tensor bbArg into a memref bbArg. Block &frontBlock = funcOp.getBody().front(); for (BlockArgument &bbArg : frontBlock.getArguments()) { - auto tensorType = bbArg.getType().dyn_cast(); + auto tensorType = llvm::dyn_cast(bbArg.getType()); // Non-tensor types stay the same. if (!tensorType) continue; @@ -404,7 +404,7 @@ SmallVector returnValues; for (OpOperand &returnOperand : returnOp->getOpOperands()) { Value returnVal = returnOperand.get(); - auto tensorType = returnVal.getType().dyn_cast(); + auto tensorType = llvm::dyn_cast(returnVal.getType()); rewriter.setInsertionPoint(returnOp); // If not a tensor type just forward it. @@ -436,7 +436,7 @@ bool isWritable(Operation *op, Value value, const AnalysisState &state) const { auto funcOp = cast(op); - BlockArgument bbArg = value.dyn_cast(); + BlockArgument bbArg = llvm::dyn_cast(value); assert(bbArg && "expected BlockArgument"); // "bufferization.writable" overrides other writability decisions. This is diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -66,7 +66,7 @@ using namespace mlir; using namespace mlir::bufferization; -static bool isaTensor(Type t) { return t.isa(); } +static bool isaTensor(Type t) { return llvm::isa(t); } //===----------------------------------------------------------------------===// // Bufferization-specific attribute manipulation. @@ -85,11 +85,11 @@ SmallVector inPlaceVector; if (auto attr = op->getAttr(kInPlaceOperandsAttrName)) { inPlaceVector = SmallVector(llvm::to_vector<4>( - attr.cast().getAsValueRange())); + llvm::cast(attr).getAsValueRange())); } else { inPlaceVector = SmallVector(op->getNumOperands(), "none"); for (OpOperand &opOperand : op->getOpOperands()) - if (opOperand.get().getType().isa()) + if (llvm::isa(opOperand.get().getType())) inPlaceVector[opOperand.getOperandNumber()] = "false"; } inPlaceVector[opOperand.getOperandNumber()] = inPlace ? "true" : "false"; @@ -107,12 +107,12 @@ // Set up alias sets. op->walk([&](Operation *op) { for (Value v : op->getResults()) - if (v.getType().isa()) + if (llvm::isa(v.getType())) createAliasInfoEntry(v); for (Region &r : op->getRegions()) for (Block &b : r.getBlocks()) for (auto bbArg : b.getArguments()) - if (bbArg.getType().isa()) + if (llvm::isa(bbArg.getType())) createAliasInfoEntry(bbArg); }); @@ -121,7 +121,7 @@ if (!options.isOpAllowed(bufferizableOp)) return WalkResult::skip(); for (OpOperand &opOperand : bufferizableOp->getOpOperands()) - if (opOperand.get().getType().isa()) + if (llvm::isa(opOperand.get().getType())) if (bufferizableOp.mustBufferizeInPlace(opOperand, *this)) bufferizeInPlace(opOperand); return WalkResult::advance(); @@ -187,13 +187,13 @@ for (OpOperand &returnValOperand : returnOp->getOpOperands()) { Value returnVal = returnValOperand.get(); // Skip non-tensor values. - if (!returnVal.getType().isa()) + if (!llvm::isa(returnVal.getType())) continue; // Add all aliases of the returned value. But only the ones that are in // the same block. applyOnAliases(returnVal, [&](Value v) { - if (auto bbArg = v.dyn_cast()) { + if (auto bbArg = llvm::dyn_cast(v)) { if (bbArg.getOwner()->getParentOp() == returnOp->getParentOp()) yieldedTensors.insert(bbArg); return; @@ -217,7 +217,7 @@ // Check all tensor OpResults. for (OpResult opResult : op->getOpResults()) { - if (!opResult.getType().isa()) + if (!llvm::isa(opResult.getType())) continue; // If there is no preceding definition, the tensor contents are @@ -259,7 +259,7 @@ return bufferizableOp.isWritable(value, *this); // Query BufferizableOpInterface to see if the BlockArgument is writable. - if (auto bbArg = value.dyn_cast()) + if (auto bbArg = llvm::dyn_cast(value)) if (auto bufferizableOp = getOptions().dynCastBufferizableOp(bbArg.getOwner()->getParentOp())) return bufferizableOp.isWritable(bbArg, *this); @@ -431,12 +431,12 @@ id + "[READ: " + std::to_string(uRead->getOperandNumber()) + "]"; readingOp->setAttr(readAttr, b.getUnitAttr()); - if (auto opResult = definition.dyn_cast()) { + if (auto opResult = llvm::dyn_cast(definition)) { std::string defAttr = id + "[DEF: result " + std::to_string(opResult.getResultNumber()) + "]"; opResult.getDefiningOp()->setAttr(defAttr, b.getUnitAttr()); } else { - auto bbArg = definition.cast(); + auto bbArg = llvm::cast(definition); std::string defAttr = id + "[DEF: bbArg " + std::to_string(bbArg.getArgNumber()) + "]"; bbArg.getOwner()->getParentOp()->setAttr(defAttr, b.getUnitAttr()); @@ -581,7 +581,7 @@ continue; } } else { - auto bbArg = definition.cast(); + auto bbArg = llvm::cast(definition); Block *block = bbArg.getOwner(); if (!block->findAncestorOpInBlock(*conflictingWritingOp)) { LLVM_DEBUG(llvm::dbgs() << " no conflict: definition is bbArg " @@ -715,12 +715,12 @@ static int64_t counter = 0; OpBuilder b(value.getContext()); std::string id = "W_" + std::to_string(counter++); - if (auto opResult = value.dyn_cast()) { + if (auto opResult = llvm::dyn_cast(value)) { std::string attr = id + "[NOT-WRITABLE: result " + std::to_string(opResult.getResultNumber()) + "]"; opResult.getDefiningOp()->setAttr(attr, b.getUnitAttr()); } else { - auto bbArg = value.cast(); + auto bbArg = llvm::cast(value); std::string attr = id + "[NOT-WRITABLE: bbArg " + std::to_string(bbArg.getArgNumber()) + "]"; bbArg.getOwner()->getParentOp()->setAttr(attr, b.getUnitAttr()); @@ -812,7 +812,7 @@ OneShotAnalysisState::analyzeSingleOp(Operation *op, const DominanceInfo &domInfo) { for (OpOperand &opOperand : op->getOpOperands()) - if (opOperand.get().getType().isa()) + if (llvm::isa(opOperand.get().getType())) if (failed(bufferizableInPlaceAnalysisImpl(opOperand, *this, domInfo))) return failure(); return success(); @@ -831,7 +831,7 @@ for (Operation *op : ops) { if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) { for (OpResult opResult : op->getOpResults()) { - if (!opResult.getType().isa()) + if (!llvm::isa(opResult.getType())) continue; AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult); if (aliases.getNumAliases() == 0) @@ -958,7 +958,7 @@ } for (OpOperand &opOperand : op->getOpOperands()) { - if (opOperand.get().getType().isa()) { + if (llvm::isa(opOperand.get().getType())) { if (wouldCreateReadAfterWriteInterference( opOperand, domInfo, state, /*checkConsistencyOnly=*/true)) { @@ -984,7 +984,7 @@ // Add __inplace_operands_attr__. op->walk([&](Operation *op) { for (OpOperand &opOperand : op->getOpOperands()) - if (opOperand.get().getType().isa()) + if (llvm::isa(opOperand.get().getType())) setInPlaceOpOperand(opOperand, state.isInPlace(opOperand)); }); } @@ -1031,12 +1031,12 @@ for (OpOperand &returnValOperand : returnOp->getOpOperands()) { Value returnVal = returnValOperand.get(); // Skip non-tensor values. - if (!returnVal.getType().isa()) + if (!llvm::isa(returnVal.getType())) continue; bool foundEquivValue = false; state.applyOnEquivalenceClass(returnVal, [&](Value equivVal) { - if (auto bbArg = equivVal.dyn_cast()) { + if (auto bbArg = llvm::dyn_cast(equivVal)) { Operation *definingOp = bbArg.getOwner()->getParentOp(); if (definingOp->isProperAncestor(returnOp)) foundEquivValue = true; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -109,9 +109,9 @@ SmallVector equivBbArgs; if (op->hasAttr(kEquivalentArgsAttr)) { - auto attr = op->getAttr(kEquivalentArgsAttr).cast(); + auto attr = llvm::cast(op->getAttr(kEquivalentArgsAttr)); equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) { - return a.cast().getValue().getSExtValue(); + return llvm::cast(a).getValue().getSExtValue(); })); } else { equivBbArgs.append(op->getNumOperands(), -1); @@ -132,10 +132,10 @@ // return value may alias with any tensor bbArg. FunctionType type = funcOp.getFunctionType(); for (const auto &inputIt : llvm::enumerate(type.getInputs())) { - if (!inputIt.value().isa()) + if (!llvm::isa(inputIt.value())) continue; for (const auto &resultIt : llvm::enumerate(type.getResults())) { - if (!resultIt.value().isa()) + if (!llvm::isa(resultIt.value())) continue; int64_t returnIdx = resultIt.index(); int64_t bbArgIdx = inputIt.index(); @@ -150,9 +150,9 @@ assert(returnOp && "expected func with single return op"); for (OpOperand &returnVal : returnOp->getOpOperands()) - if (returnVal.get().getType().isa()) + if (llvm::isa(returnVal.get().getType())) for (BlockArgument bbArg : funcOp.getArguments()) - if (bbArg.getType().isa()) { + if (llvm::isa(bbArg.getType())) { int64_t returnIdx = returnVal.getOperandNumber(); int64_t bbArgIdx = bbArg.getArgNumber(); if (state.areEquivalentBufferizedValues(returnVal.get(), bbArg)) { @@ -193,7 +193,7 @@ for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e; ++idx) { // Skip non-tensor arguments. - if (!funcOp.getFunctionType().getInput(idx).isa()) + if (!llvm::isa(funcOp.getFunctionType().getInput(idx))) continue; bool isRead; bool isWritten; diff --git a/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp b/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp --- a/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp +++ b/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp @@ -34,7 +34,7 @@ Location loc) { if (complex::ConstantOp::isBuildableWith(value, type)) { return builder.create(loc, type, - value.cast()); + llvm::cast(value)); } return arith::ConstantOp::materialize(builder, value, type, loc); } @@ -46,16 +46,16 @@ ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::llvm::APFloat real, ::llvm::APFloat imag, ::mlir::Type type) { - if (!type.isa()) + if (!llvm::isa(type)) return emitError() << "complex attribute must be a complex type."; - Type elementType = type.cast().getElementType(); - if (!elementType.isa()) + Type elementType = llvm::cast(type).getElementType(); + if (!llvm::isa(elementType)) return emitError() << "element type of the complex attribute must be float like type."; const auto &typeFloatSemantics = - elementType.cast().getFloatSemantics(); + llvm::cast(elementType).getFloatSemantics(); if (&real.getSemantics() != &typeFloatSemantics) return emitError() << "type doesn't match the type implied by its `real` value"; @@ -67,7 +67,7 @@ } void complex::NumberAttr::print(AsmPrinter &printer) const { - printer << "<:" << getType().cast().getElementType() << " " + printer << "<:" << llvm::cast(getType()).getElementType() << " " << getReal() << ", " << getImag() << ">"; } diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp --- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp +++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp @@ -27,18 +27,18 @@ } bool ConstantOp::isBuildableWith(Attribute value, Type type) { - if (auto arrAttr = value.dyn_cast()) { - auto complexTy = type.dyn_cast(); + if (auto arrAttr = llvm::dyn_cast(value)) { + auto complexTy = llvm::dyn_cast(type); if (!complexTy || arrAttr.size() != 2) return false; auto complexEltTy = complexTy.getElementType(); - if (auto fre = arrAttr[0].dyn_cast()) { - auto im = arrAttr[1].dyn_cast(); + if (auto fre = llvm::dyn_cast(arrAttr[0])) { + auto im = llvm::dyn_cast(arrAttr[1]); return im && fre.getType() == complexEltTy && im.getType() == complexEltTy; } - if (auto ire = arrAttr[0].dyn_cast()) { - auto im = arrAttr[1].dyn_cast(); + if (auto ire = llvm::dyn_cast(arrAttr[0])) { + auto im = llvm::dyn_cast(arrAttr[1]); return im && ire.getType() == complexEltTy && im.getType() == complexEltTy; } @@ -55,8 +55,8 @@ } auto complexEltTy = getType().getElementType(); - auto re = arrayAttr[0].dyn_cast(); - auto im = arrayAttr[1].dyn_cast(); + auto re = llvm::dyn_cast(arrayAttr[0]); + auto im = llvm::dyn_cast(arrayAttr[1]); if (!re || !im) return emitOpError("requires attribute's elements to be float attributes"); if (complexEltTy != re.getType() || complexEltTy != im.getType()) { @@ -129,8 +129,8 @@ // complex.add(a, complex.constant<0.0, 0.0>) -> a if (auto constantOp = getRhs().getDefiningOp()) { auto arrayAttr = constantOp.getValue(); - if (arrayAttr[0].cast().getValue().isZero() && - arrayAttr[1].cast().getValue().isZero()) { + if (llvm::cast(arrayAttr[0]).getValue().isZero() && + llvm::cast(arrayAttr[1]).getValue().isZero()) { return getLhs(); } } @@ -151,8 +151,8 @@ // complex.sub(a, complex.constant<0.0, 0.0>) -> a if (auto constantOp = getRhs().getDefiningOp()) { auto arrayAttr = constantOp.getValue(); - if (arrayAttr[0].cast().getValue().isZero() && - arrayAttr[1].cast().getValue().isZero()) { + if (llvm::cast(arrayAttr[0]).getValue().isZero() && + llvm::cast(arrayAttr[1]).getValue().isZero()) { return getLhs(); } } diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp --- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp @@ -125,7 +125,7 @@ // Otherwise, we need to remap any argument operands. for (Value operand : operands) { - BlockArgument argOperand = operand.dyn_cast(); + BlockArgument argOperand = llvm::dyn_cast(operand); if (argOperand && argOperand.getOwner() == successor) argStorage.push_back(successorOperands[argOperand.getArgNumber()]); else @@ -442,7 +442,8 @@ } Block *CondBranchOp::getSuccessorForOperands(ArrayRef operands) { - if (IntegerAttr condAttr = operands.front().dyn_cast_or_null()) + if (IntegerAttr condAttr = + llvm::dyn_cast_or_null(operands.front())) return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest(); return nullptr; } @@ -601,7 +602,7 @@ return getDefaultDestination(); SuccessorRange caseDests = getCaseDestinations(); - if (auto value = operands.front().dyn_cast_or_null()) { + if (auto value = llvm::dyn_cast_or_null(operands.front())) { for (const auto &it : llvm::enumerate(caseValues->getValues())) if (it.value() == value.getValue()) return caseDests[it.index()]; diff --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp --- a/mlir/lib/Dialect/DLTI/DLTI.cpp +++ b/mlir/lib/Dialect/DLTI/DLTI.cpp @@ -215,7 +215,7 @@ typeSample.getContext()->getLoadedDialect() && "unexpected data layout entry for built-in type"); - auto interface = typeSample.cast(); + auto interface = llvm::cast(typeSample); if (!interface.areCompatible(entriesForType.lookup(kvp.first), kvp.second)) return failure(); @@ -250,7 +250,7 @@ // Only combine with attributes of the same kind. // TODO: reconsider this when the need arises. if (llvm::any_of(specs, [](DataLayoutSpecInterface spec) { - return !spec.isa(); + return !llvm::isa(spec); })) return {}; @@ -334,7 +334,7 @@ Location loc) const final { StringRef entryName = entry.getKey().get().strref(); if (entryName == DLTIDialect::kDataLayoutEndiannessKey) { - auto value = entry.getValue().dyn_cast(); + auto value = llvm::dyn_cast(entry.getValue()); if (value && (value.getValue() == DLTIDialect::kDataLayoutEndiannessBig || value.getValue() == DLTIDialect::kDataLayoutEndiannessLittle)) @@ -383,7 +383,7 @@ LogicalResult DLTIDialect::verifyOperationAttribute(Operation *op, NamedAttribute attr) { if (attr.getName() == DLTIDialect::kDataLayoutAttrName) { - if (!attr.getValue().isa()) { + if (!llvm::isa(attr.getValue())) { return op->emitError() << "'" << DLTIDialect::kDataLayoutAttrName << "' is expected to be a #dlti.dl_spec attribute"; } diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -73,10 +73,10 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { Type input = inputs.front(), output = outputs.front(); - return ((input.isa()) && - (output.isa())); + return ((llvm::isa(input)) && + (llvm::isa(output))); } //===----------------------------------------------------------------------===// @@ -90,8 +90,8 @@ if (std::optional argsAttr = getArgs()) { for (Attribute arg : *argsAttr) { - auto intAttr = arg.dyn_cast(); - if (intAttr && intAttr.getType().isa()) { + auto intAttr = llvm::dyn_cast(arg); + if (intAttr && llvm::isa(intAttr.getType())) { int64_t index = intAttr.getInt(); // Args with elements of type index must be in range // [0..operands.size). @@ -99,7 +99,8 @@ return emitOpError("index argument is out of range"); // Args with elements of type ArrayAttr must have a type. - } else if (arg.isa() /*&& arg.getType().isa()*/) { + } else if (llvm::isa( + arg) /*&& arg.getType().isa()*/) { // FIXME: Array attributes never have types return emitOpError("array argument has no type"); } @@ -108,7 +109,7 @@ if (std::optional templateArgsAttr = getTemplateArgs()) { for (Attribute tArg : *templateArgsAttr) { - if (!tArg.isa()) + if (!llvm::isa(tArg)) return emitOpError("template argument has invalid type"); } } @@ -122,17 +123,17 @@ /// The constant op requires that the attribute's type matches the return type. LogicalResult emitc::ConstantOp::verify() { - if (getValueAttr().isa()) + if (llvm::isa(getValueAttr())) return success(); // Value must not be empty - StringAttr strAttr = getValueAttr().dyn_cast(); + StringAttr strAttr = llvm::dyn_cast(getValueAttr()); if (strAttr && strAttr.getValue().empty()) return emitOpError() << "value must not be empty"; auto value = cast(getValueAttr()); Type type = getType(); - if (!value.getType().isa() && type != value.getType()) + if (!llvm::isa(value.getType()) && type != value.getType()) return emitOpError() << "requires attribute's type (" << value.getType() << ") to match op's return type (" << type << ")"; return success(); @@ -183,12 +184,12 @@ /// The variable op requires that the attribute's type matches the return type. LogicalResult emitc::VariableOp::verify() { - if (getValueAttr().isa()) + if (llvm::isa(getValueAttr())) return success(); auto value = cast(getValueAttr()); Type type = getType(); - if (!value.getType().isa() && type != value.getType()) + if (!llvm::isa(value.getType()) && type != value.getType()) return emitOpError() << "requires attribute's type (" << value.getType() << ") to match op's return type (" << type << ")"; return success(); diff --git a/mlir/lib/Dialect/Func/IR/FuncOps.cpp b/mlir/lib/Dialect/Func/IR/FuncOps.cpp --- a/mlir/lib/Dialect/Func/IR/FuncOps.cpp +++ b/mlir/lib/Dialect/Func/IR/FuncOps.cpp @@ -112,7 +112,7 @@ Type type, Location loc) { if (ConstantOp::isBuildableWith(value, type)) return builder.create(loc, type, - value.cast()); + llvm::cast(value)); return nullptr; } @@ -209,7 +209,7 @@ } bool ConstantOp::isBuildableWith(Attribute value, Type type) { - return value.isa() && type.isa(); + return llvm::isa(value) && llvm::isa(type); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -220,7 +220,7 @@ LogicalResult GPUDialect::verifyOperationAttribute(Operation *op, NamedAttribute attr) { - if (!attr.getValue().isa() || + if (!llvm::isa(attr.getValue()) || attr.getName() != getContainerModuleAttrName()) return success(); @@ -368,14 +368,14 @@ ArrayRef attributions, gpu::AddressSpace memorySpace) { for (Value v : attributions) { - auto type = v.getType().dyn_cast(); + auto type = llvm::dyn_cast(v.getType()); if (!type) return op->emitOpError() << "expected memref type in attribution"; // We can only verify the address space if it hasn't already been lowered // from the AddressSpaceAttr to a target-specific numeric value. auto addressSpace = - type.getMemorySpace().dyn_cast_or_null(); + llvm::dyn_cast_or_null(type.getMemorySpace()); if (!addressSpace) continue; if (addressSpace.getValue() != memorySpace) @@ -395,7 +395,7 @@ return (opName != gpu::AllReduceOperation::AND && opName != gpu::AllReduceOperation::OR && opName != gpu::AllReduceOperation::XOR) || - resType.isa(); + llvm::isa(resType); } LogicalResult gpu::AllReduceOp::verifyRegions() { @@ -1186,7 +1186,7 @@ size_t attributionIndex = pair.index(); DictionaryAttr attrs; if (attributes && attributionIndex < attributes.size()) - attrs = attributes[attributionIndex].cast(); + attrs = llvm::cast(attributes[attributionIndex]); if (attrs) p.printOptionalAttrDict(attrs.getValue()); }); @@ -1221,10 +1221,10 @@ static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index, StringAttr attrName) { - auto allAttrs = op->getAttr(attrName).dyn_cast_or_null(); + auto allAttrs = llvm::dyn_cast_or_null(op->getAttr(attrName)); if (!allAttrs || index >= allAttrs.size()) return DictionaryAttr(); - return allAttrs[index].cast(); + return llvm::cast(allAttrs[index]); } DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) { @@ -1238,7 +1238,7 @@ static void setAttributionAttrs(GPUFuncOp op, unsigned index, DictionaryAttr value, StringAttr attrName) { MLIRContext *ctx = op.getContext(); - auto allAttrs = op->getAttr(attrName).dyn_cast_or_null(); + auto allAttrs = llvm::dyn_cast_or_null(op->getAttr(attrName)); SmallVector elements; if (allAttrs) elements.append(allAttrs.begin(), allAttrs.end()); @@ -1379,7 +1379,7 @@ auto maybeAttr = op->getAttr(attrName); if (!maybeAttr) return success(); - auto array = maybeAttr.dyn_cast(); + auto array = llvm::dyn_cast(maybeAttr); if (!array) return op.emitOpError(attrName + " must be a dense i32 array"); if (array.size() != 3) @@ -1536,9 +1536,9 @@ LogicalResult SubgroupMmaLoadMatrixOp::verify() { auto srcType = getSrcMemref().getType(); auto resType = getRes().getType(); - auto resMatrixType = resType.cast(); + auto resMatrixType = llvm::cast(resType); auto operand = resMatrixType.getOperand(); - auto srcMemrefType = srcType.cast(); + auto srcMemrefType = llvm::cast(srcType); if (!isLastMemrefDimUnitStride(srcMemrefType)) return emitError( @@ -1558,8 +1558,8 @@ LogicalResult SubgroupMmaStoreMatrixOp::verify() { auto srcType = getSrc().getType(); auto dstType = getDstMemref().getType(); - auto srcMatrixType = srcType.cast(); - auto dstMemrefType = dstType.cast(); + auto srcMatrixType = llvm::cast(srcType); + auto dstMemrefType = llvm::cast(dstType); if (!isLastMemrefDimUnitStride(dstMemrefType)) return emitError( @@ -1579,9 +1579,9 @@ LogicalResult SubgroupMmaComputeOp::verify() { enum OperandMap { A, B, C }; SmallVector opTypes; - opTypes.push_back(getOpA().getType().cast()); - opTypes.push_back(getOpB().getType().cast()); - opTypes.push_back(getOpC().getType().cast()); + opTypes.push_back(llvm::cast(getOpA().getType())); + opTypes.push_back(llvm::cast(getOpB().getType())); + opTypes.push_back(llvm::cast(getOpC().getType())); if (!opTypes[A].getOperand().equals("AOp") || !opTypes[B].getOperand().equals("BOp") || @@ -1688,7 +1688,7 @@ //===----------------------------------------------------------------------===// LogicalResult AllocOp::verify() { - auto memRefType = getMemref().getType().cast(); + auto memRefType = llvm::cast(getMemref().getType()); if (static_cast(getDynamicSizes().size()) != memRefType.getNumDynamicDims()) @@ -1719,7 +1719,7 @@ if (!index) return failure(); - auto memrefType = dimOp.getSource().getType().dyn_cast(); + auto memrefType = llvm::dyn_cast(dimOp.getSource().getType()); if (!memrefType || !memrefType.isDynamicDim(index.value())) return failure(); diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp --- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp @@ -214,7 +214,7 @@ /// Returns an accumulator factory that creates an op specified by opName. AccumulatorFactory getFactory(gpu::AllReduceOperation opName) { - bool isFloatingPoint = valueType.isa(); + bool isFloatingPoint = llvm::isa(valueType); switch (opName) { case gpu::AllReduceOperation::ADD: return isFloatingPoint ? getFactory() diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -338,7 +338,7 @@ if (!resultAttr) return failure(); - dataLayoutSpec = resultAttr.dyn_cast(); + dataLayoutSpec = llvm::dyn_cast(resultAttr); if (!dataLayoutSpec) return failure(); } @@ -410,7 +410,8 @@ SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) { for (SymbolTable::SymbolUse symbolUse : *symbolUses) { StringRef symbolName = - symbolUse.getSymbolRef().cast().getValue(); + llvm::cast(symbolUse.getSymbolRef()) + .getValue(); if (symbolTable.lookup(symbolName)) continue; diff --git a/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp b/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp --- a/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp @@ -30,7 +30,7 @@ /// single-iteration loops. Maps the innermost loops to thread dimensions, in /// reverse order to enable access coalescing in the innermost loop. static void insertCopyLoops(ImplicitLocOpBuilder &b, Value from, Value to) { - auto memRefType = from.getType().cast(); + auto memRefType = llvm::cast(from.getType()); auto rank = memRefType.getRank(); SmallVector lbs, ubs, steps; @@ -121,8 +121,8 @@ /// pointed to by "from". In case a smaller block would be sufficient, the /// caller can create a subview of the memref and promote it instead. static void insertCopies(Region ®ion, Location loc, Value from, Value to) { - auto fromType = from.getType().cast(); - auto toType = to.getType().cast(); + auto fromType = llvm::cast(from.getType()); + auto toType = llvm::cast(to.getType()); (void)fromType; (void)toType; assert(fromType.getShape() == toType.getShape()); @@ -143,7 +143,7 @@ /// copies will be inserted in the beginning and in the end of the function. void mlir::promoteToWorkgroupMemory(GPUFuncOp op, unsigned arg) { Value value = op.getArgument(arg); - auto type = value.getType().dyn_cast(); + auto type = llvm::dyn_cast(value.getType()); assert(type && type.hasStaticShape() && "can only promote memrefs"); // Get the type of the buffer in the workgroup memory. diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp --- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp +++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp @@ -39,7 +39,8 @@ // Materialize integer attributes as `index`. if (auto indexValue = dyn_cast(value)) { - if (!indexValue.getType().isa() || !type.isa()) + if (!llvm::isa(indexValue.getType()) || + !llvm::isa(type)) return nullptr; assert(indexValue.getValue().getBitWidth() == IndexType::kInternalStorageBitWidth); @@ -399,7 +400,8 @@ //===----------------------------------------------------------------------===// bool CastSOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) { - return lhsTypes.front().isa() != rhsTypes.front().isa(); + return llvm::isa(lhsTypes.front()) != + llvm::isa(rhsTypes.front()); } //===----------------------------------------------------------------------===// @@ -407,7 +409,8 @@ //===----------------------------------------------------------------------===// bool CastUOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) { - return lhsTypes.front().isa() != rhsTypes.front().isa(); + return llvm::isa(lhsTypes.front()) != + llvm::isa(rhsTypes.front()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -179,7 +179,7 @@ //===----------------------------------------------------------------------===// void AllocaOp::print(OpAsmPrinter &p) { - Type elemTy = getType().cast().getElementType(); + Type elemTy = llvm::cast(getType()).getElementType(); if (!elemTy) elemTy = *getElemType(); @@ -220,7 +220,7 @@ std::optional alignmentAttr = result.attributes.getNamed("alignment"); if (alignmentAttr.has_value()) { - auto alignmentInt = alignmentAttr->getValue().dyn_cast(); + auto alignmentInt = llvm::dyn_cast(alignmentAttr->getValue()); if (!alignmentInt) return parser.emitError(parser.getNameLoc(), "expected integer alignment"); @@ -229,7 +229,7 @@ } // Extract the result type from the trailing function type. - auto funcType = type.dyn_cast(); + auto funcType = llvm::dyn_cast(type); if (!funcType || funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) return parser.emitError( @@ -240,7 +240,7 @@ return failure(); Type resultType = funcType.getResult(0); - if (auto ptrResultType = resultType.dyn_cast()) { + if (auto ptrResultType = llvm::dyn_cast(resultType)) { if (ptrResultType.isOpaque()) result.addAttribute(kElemTypeAttrName, TypeAttr::get(elemType)); } @@ -266,7 +266,7 @@ } LogicalResult AllocaOp::verify() { - return verifyOpaquePtr(getOperation(), getType().cast(), + return verifyOpaquePtr(getOperation(), llvm::cast(getType()), getElemType()); } @@ -410,7 +410,7 @@ size_t index = 0; llvm::interleave( - llvm::zip(caseValues.cast(), caseDestinations), + llvm::zip(llvm::cast(caseValues), caseDestinations), [&](auto i) { p << " "; p << std::get<0>(i).getLimitedValue(); @@ -457,11 +457,11 @@ /// Returns the elemental type of any LLVM-compatible vector type or self. static Type extractVectorElementType(Type type) { - if (auto vectorType = type.dyn_cast()) + if (auto vectorType = llvm::dyn_cast(type)) return vectorType.getElementType(); - if (auto scalableVectorType = type.dyn_cast()) + if (auto scalableVectorType = llvm::dyn_cast(type)) return scalableVectorType.getElementType(); - if (auto fixedVectorType = type.dyn_cast()) + if (auto fixedVectorType = llvm::dyn_cast(type)) return fixedVectorType.getElementType(); return type; } @@ -470,7 +470,7 @@ Value basePtr, ArrayRef indices, bool inbounds, ArrayRef attributes) { auto ptrType = - extractVectorElementType(basePtr.getType()).cast(); + llvm::cast(extractVectorElementType(basePtr.getType())); assert(!ptrType.isOpaque() && "expected non-opaque pointer, provide elementType explicitly when " "opaque pointers are used"); @@ -543,8 +543,7 @@ result.addAttribute(getInboundsAttrName(result.name), builder.getUnitAttr()); } - if (extractVectorElementType(basePtr.getType()) - .cast() + if (llvm::cast(extractVectorElementType(basePtr.getType())) .isOpaque()) result.addAttribute(kElemTypeAttrName, TypeAttr::get(elementType)); result.addOperands(basePtr); @@ -695,7 +694,7 @@ LogicalResult LLVM::GEPOp::verify() { if (failed(verifyOpaquePtr( getOperation(), - extractVectorElementType(getType()).cast(), + llvm::cast(extractVectorElementType(getType())), getElemType()))) return failure(); @@ -716,8 +715,8 @@ if (std::optional elemType = getElemType()) return *elemType; - return extractVectorElementType(getBase().getType()) - .cast() + return llvm::cast( + extractVectorElementType(getBase().getType())) .getElementType(); } @@ -729,16 +728,16 @@ /// integer and float types with limited bit width are supported. Additionally, /// depending on the operation pointers may be supported as well. static bool isTypeCompatibleWithAtomicOp(Type type, bool isPointerTypeAllowed) { - if (type.isa()) + if (llvm::isa(type)) return isPointerTypeAllowed; std::optional bitWidth; - if (auto floatType = type.dyn_cast()) { + if (auto floatType = llvm::dyn_cast(type)) { if (!isCompatibleFloatingPointType(type)) return false; bitWidth = floatType.getWidth(); } - if (auto integerType = type.dyn_cast()) + if (auto integerType = llvm::dyn_cast(type)) bitWidth = integerType.getWidth(); // The type is neither an integer, float, or pointer type. if (!bitWidth) @@ -777,7 +776,7 @@ void LoadOp::build(OpBuilder &builder, OperationState &state, Value addr, unsigned alignment, bool isVolatile, bool isNonTemporal) { - auto type = addr.getType().cast().getElementType(); + auto type = llvm::cast(addr.getType()).getElementType(); assert(type && "must provide explicit element type to the constructor " "when the pointer type is opaque"); build(builder, state, type, addr, alignment, isVolatile, isNonTemporal); @@ -801,7 +800,7 @@ // std::nullopt if the given type is not the pointer type. static std::optional getLoadStoreElementType(OpAsmParser &parser, Type type, SMLoc trailingTypeLoc) { - auto llvmTy = type.dyn_cast(); + auto llvmTy = llvm::dyn_cast(type); if (!llvmTy) { parser.emitError(trailingTypeLoc, "expected LLVM pointer type"); return std::nullopt; @@ -919,7 +918,7 @@ ValueRange args) { SmallVector results; Type resultType = func.getFunctionType().getReturnType(); - if (!resultType.isa()) + if (!llvm::isa(resultType)) results.push_back(resultType); build(builder, state, results, SymbolRefAttr::get(func), args, nullptr, nullptr); @@ -964,7 +963,7 @@ if (!getNumOperands()) return emitOpError( "must have either a `callee` attribute or at least an operand"); - auto ptrType = getOperand(0).getType().dyn_cast(); + auto ptrType = llvm::dyn_cast(getOperand(0).getType()); if (!ptrType) return emitOpError("indirect call expects a pointer as callee: ") << getOperand(0).getType(); @@ -988,7 +987,7 @@ fnType = fn.getFunctionType(); } - LLVMFunctionType funcType = fnType.dyn_cast(); + LLVMFunctionType funcType = llvm::dyn_cast(fnType); if (!funcType) return emitOpError("callee does not have a functional type: ") << fnType; @@ -1023,11 +1022,11 @@ << " != " << funcType.getParamType(i); if (getNumResults() == 0 && - !funcType.getReturnType().isa()) + !llvm::isa(funcType.getReturnType())) return emitOpError() << "expected function call to produce a value"; if (getNumResults() != 0 && - funcType.getReturnType().isa()) + llvm::isa(funcType.getReturnType())) return emitOpError() << "calling function with void result must not produce values"; @@ -1083,7 +1082,7 @@ return parser.emitError(trailingTypesLoc, "expected indirect call to have 2 trailing types"); - auto funcType = types.pop_back_val().dyn_cast(); + auto funcType = llvm::dyn_cast(types.pop_back_val()); if (!funcType) return parser.emitError(trailingTypesLoc, "expected trailing function type"); @@ -1091,7 +1090,7 @@ return parser.emitError(trailingTypesLoc, "expected function with 0 or 1 result"); if (funcType.getNumResults() == 1 && - funcType.getResult(0).isa()) + llvm::isa(funcType.getResult(0))) return parser.emitError(trailingTypesLoc, "expected a non-void result type"); @@ -1292,7 +1291,7 @@ for (unsigned idx = 0, ie = getNumOperands(); idx < ie; idx++) { value = getOperand(idx); - bool isFilter = value.getType().isa(); + bool isFilter = llvm::isa(value.getType()); if (isFilter) { // FIXME: Verify filter clauses when arrays are appropriately handled } else { @@ -1324,7 +1323,7 @@ for (auto value : getOperands()) { // Similar to llvm - if clause is an array type then it is filter // clause else catch clause - bool isArrayTy = value.getType().isa(); + bool isArrayTy = llvm::isa(value.getType()); p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : " << value.getType() << ") "; } @@ -1383,13 +1382,13 @@ // structures. Check the position index before accessing, it is supposed to // be in bounds. for (int64_t idx : position) { - if (auto arrayType = llvmType.dyn_cast()) { + if (auto arrayType = llvm::dyn_cast(llvmType)) { if (idx < 0 || static_cast(idx) >= arrayType.getNumElements()) { emitError("position out of bounds: ") << idx; return {}; } llvmType = arrayType.getElementType(); - } else if (auto structType = llvmType.dyn_cast()) { + } else if (auto structType = llvm::dyn_cast(llvmType)) { if (idx < 0 || static_cast(idx) >= structType.getBody().size()) { emitError("position out of bounds: ") << idx; @@ -1409,10 +1408,10 @@ static Type getInsertExtractValueElementType(Type llvmType, ArrayRef position) { for (int64_t idx : position) { - if (auto structType = llvmType.dyn_cast()) + if (auto structType = llvm::dyn_cast(llvmType)) llvmType = structType.getBody()[idx]; else - llvmType = llvmType.cast().getElementType(); + llvmType = llvm::cast(llvmType).getElementType(); } return llvmType; } @@ -1519,7 +1518,7 @@ return success(); Type expectedType = parent.getFunctionType().getReturnType(); - if (expectedType.isa()) { + if (llvm::isa(expectedType)) { if (!getArg()) return success(); InFlightDiagnostic diag = emitOpError("expected no operands"); @@ -1527,7 +1526,7 @@ return diag; } if (!getArg()) { - if (expectedType.isa()) + if (llvm::isa(expectedType)) return success(); InFlightDiagnostic diag = emitOpError("expected 1 operand"); diag.attachNote(parent->getLoc()) << "when returning from function"; @@ -1664,7 +1663,7 @@ getVisibility_AttrName()}); // Print the trailing type unless it's a string global. - if (getValueOrNull().dyn_cast_or_null()) + if (llvm::dyn_cast_or_null(getValueOrNull())) return; p << " : " << getType(); @@ -1779,7 +1778,7 @@ Region &initRegion = *result.addRegion(); if (types.empty()) { - if (auto strAttr = value.dyn_cast_or_null()) { + if (auto strAttr = llvm::dyn_cast_or_null(value)) { MLIRContext *context = parser.getContext(); auto arrayType = LLVM::LLVMArrayType::get(IntegerType::get(context, 8), strAttr.getValue().size()); @@ -1802,15 +1801,15 @@ } static bool isZeroAttribute(Attribute value) { - if (auto intValue = value.dyn_cast()) + if (auto intValue = llvm::dyn_cast(value)) return intValue.getValue().isZero(); - if (auto fpValue = value.dyn_cast()) + if (auto fpValue = llvm::dyn_cast(value)) return fpValue.getValue().isZero(); - if (auto splatValue = value.dyn_cast()) + if (auto splatValue = llvm::dyn_cast(value)) return isZeroAttribute(splatValue.getSplatValue()); - if (auto elementsValue = value.dyn_cast()) + if (auto elementsValue = llvm::dyn_cast(value)) return llvm::all_of(elementsValue.getValues(), isZeroAttribute); - if (auto arrayValue = value.dyn_cast()) + if (auto arrayValue = llvm::dyn_cast(value)) return llvm::all_of(arrayValue.getValue(), isZeroAttribute); return false; } @@ -1822,10 +1821,10 @@ if ((*this)->getParentOp() && !satisfiesLLVMModule((*this)->getParentOp())) return emitOpError("must appear at the module level"); - if (auto strAttr = getValueOrNull().dyn_cast_or_null()) { - auto type = getType().dyn_cast(); + if (auto strAttr = llvm::dyn_cast_or_null(getValueOrNull())) { + auto type = llvm::dyn_cast(getType()); IntegerType elementType = - type ? type.getElementType().dyn_cast() : nullptr; + type ? llvm::dyn_cast(type.getElementType()) : nullptr; if (!elementType || elementType.getWidth() != 8 || type.getNumElements() != strAttr.getValue().size()) return emitOpError( @@ -1844,7 +1843,7 @@ } if (getLinkage() == Linkage::Appending) { - if (!getType().isa()) { + if (!llvm::isa(getType())) { return emitOpError() << "expected array type for '" << stringifyLinkage(Linkage::Appending) << "' linkage"; @@ -1892,7 +1891,7 @@ LogicalResult GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) { for (Attribute ctor : getCtors()) { - if (failed(verifySymbolAttrUse(ctor.cast(), *this, + if (failed(verifySymbolAttrUse(llvm::cast(ctor), *this, symbolTable))) return failure(); } @@ -1913,7 +1912,7 @@ LogicalResult GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) { for (Attribute dtor : getDtors()) { - if (failed(verifySymbolAttrUse(dtor.cast(), *this, + if (failed(verifySymbolAttrUse(llvm::cast(dtor), *this, symbolTable))) return failure(); } @@ -2012,7 +2011,7 @@ if (argAttrs.empty()) return; - assert(type.cast().getNumParams() == argAttrs.size() && + assert(llvm::cast(type).getNumParams() == argAttrs.size() && "expected as many argument attribute lists as arguments"); function_interface_impl::addArgAndResultAttrs( builder, result, argAttrs, /*resultAttrs=*/std::nullopt, @@ -2143,7 +2142,7 @@ argTypes.push_back(fnType.getParamType(i)); Type returnType = fnType.getReturnType(); - if (!returnType.isa()) + if (!llvm::isa(returnType)) resTypes.push_back(returnType); function_interface_impl::printFunctionSignature(p, *this, argTypes, @@ -2251,8 +2250,8 @@ //===----------------------------------------------------------------------===// LogicalResult LLVM::ConstantOp::verify() { - if (StringAttr sAttr = getValue().dyn_cast()) { - auto arrayType = getType().dyn_cast(); + if (StringAttr sAttr = llvm::dyn_cast(getValue())) { + auto arrayType = llvm::dyn_cast(getType()); if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() || !arrayType.getElementType().isInteger(8)) { return emitOpError() << "expected array type of " @@ -2261,35 +2260,35 @@ } return success(); } - if (auto structType = getType().dyn_cast()) { + if (auto structType = llvm::dyn_cast(getType())) { if (structType.getBody().size() != 2 || structType.getBody()[0] != structType.getBody()[1]) { return emitError() << "expected struct type with two elements of the " "same type, the type of a complex constant"; } - auto arrayAttr = getValue().dyn_cast(); + auto arrayAttr = llvm::dyn_cast(getValue()); if (!arrayAttr || arrayAttr.size() != 2) { return emitOpError() << "expected array attribute with two elements, " "representing a complex constant"; } - auto re = arrayAttr[0].dyn_cast(); - auto im = arrayAttr[1].dyn_cast(); + auto re = llvm::dyn_cast(arrayAttr[0]); + auto im = llvm::dyn_cast(arrayAttr[1]); if (!re || !im || re.getType() != im.getType()) { return emitOpError() << "expected array attribute with two elements of the same type"; } Type elementType = structType.getBody()[0]; - if (!elementType - .isa()) { + if (!llvm::isa( + elementType)) { return emitError() << "expected struct element types to be floating point type or " "integer type"; } return success(); } - if (!getValue().isa()) + if (!llvm::isa(getValue())) return emitOpError() << "only supports integer, float, string or elements attributes"; return success(); @@ -2314,7 +2313,7 @@ } LogicalResult AtomicRMWOp::verify() { - auto ptrType = getPtr().getType().cast(); + auto ptrType = llvm::cast(getPtr().getType()); auto valType = getVal().getType(); if (!ptrType.isOpaque() && valType != ptrType.getElementType()) return emitOpError("expected LLVM IR element type for operand #0 to " @@ -2327,7 +2326,7 @@ if (!isTypeCompatibleWithAtomicOp(valType, /*isPointerTypeAllowed=*/false)) return emitOpError("unexpected LLVM IR type for 'xchg' bin_op"); } else { - auto intType = valType.dyn_cast(); + auto intType = llvm::dyn_cast(valType); unsigned intBitWidth = intType ? intType.getWidth() : 0; if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && intBitWidth != 64) @@ -2367,7 +2366,7 @@ } LogicalResult AtomicCmpXchgOp::verify() { - auto ptrType = getPtr().getType().cast(); + auto ptrType = llvm::cast(getPtr().getType()); if (!ptrType) return emitOpError("expected LLVM IR pointer type for operand #0"); auto valType = getVal().getType(); @@ -2421,10 +2420,10 @@ } LogicalResult LLVM::BitcastOp::verify() { - auto resultType = extractVectorElementType(getResult().getType()) - .dyn_cast(); - auto sourceType = - extractVectorElementType(getArg().getType()).dyn_cast(); + auto resultType = llvm::dyn_cast( + extractVectorElementType(getResult().getType())); + auto sourceType = llvm::dyn_cast( + extractVectorElementType(getArg().getType())); // If one of the types is a pointer (or vector of pointers), then // both source and result type have to be pointers. @@ -2435,7 +2434,8 @@ return success(); auto isVector = [](Type type) { - return type.isa(); + return llvm::isa( + type); }; // Due to bitcast requiring both operands to be of the same size, it is not @@ -2480,7 +2480,7 @@ // gep %x:T, 0 -> %x if (getBase().getType() == getType() && indices.size() == 1) - if (auto integer = indices[0].dyn_cast_or_null()) + if (auto integer = llvm::dyn_cast_or_null(indices[0])) if (integer.getValue().isZero()) return getBase(); @@ -2488,7 +2488,7 @@ bool changed = false; SmallVector gepArgs; for (auto iter : llvm::enumerate(indices)) { - auto integer = iter.value().dyn_cast_or_null(); + auto integer = llvm::dyn_cast_or_null(iter.value()); // Constant indices can only be int32_t, so if integer does not fit we // are forced to keep it dynamic, despite being a constant. if (!indices.isDynamicIndex(iter.index()) || !integer || @@ -2686,7 +2686,7 @@ SmallVectorImpl &operands = tbaaGraph[tdOp.getSymNameAttr()]->operands; for (Attribute attr : tdOp.getMembers()) { - StringAttr symbolRef = attr.cast().getAttr(); + StringAttr symbolRef = llvm::cast(attr).getAttr(); if (failed(verifyReference(op, symbolRef, tdOp.getMembersAttrName()))) return failure(); @@ -2888,7 +2888,7 @@ // llvm::DataLayout constructor. if (attr.getName() != LLVM::LLVMDialect::getDataLayoutAttrName()) return success(); - if (auto stringAttr = attr.getValue().dyn_cast()) + if (auto stringAttr = llvm::dyn_cast(attr.getValue())) return verifyDataLayoutString( stringAttr.getValue(), [op](const Twine &message) { op->emitOpError() << message.str(); }); @@ -2909,28 +2909,28 @@ StringAttr name = paramAttr.getName(); auto checkUnitAttrType = [&]() -> LogicalResult { - if (!paramAttr.getValue().isa()) + if (!llvm::isa(paramAttr.getValue())) return op->emitError() << name << " should be a unit attribute"; return success(); }; auto checkTypeAttrType = [&]() -> LogicalResult { - if (!paramAttr.getValue().isa()) + if (!llvm::isa(paramAttr.getValue())) return op->emitError() << name << " should be a type attribute"; return success(); }; auto checkIntegerAttrType = [&]() -> LogicalResult { - if (!paramAttr.getValue().isa()) + if (!llvm::isa(paramAttr.getValue())) return op->emitError() << name << " should be an integer attribute"; return success(); }; auto checkPointerType = [&]() -> LogicalResult { - if (!paramType.isa()) + if (!llvm::isa(paramType)) return op->emitError() << name << " attribute attached to non-pointer LLVM type"; return success(); }; auto checkIntegerType = [&]() -> LogicalResult { - if (!paramType.isa()) + if (!llvm::isa(paramType)) return op->emitError() << name << " attribute attached to non-integer LLVM type"; return success(); @@ -2938,8 +2938,8 @@ auto checkPointerTypeMatches = [&]() -> LogicalResult { if (failed(checkPointerType())) return failure(); - auto ptrType = paramType.cast(); - auto typeAttr = paramAttr.getValue().cast(); + auto ptrType = llvm::cast(paramType); + auto typeAttr = llvm::cast(paramAttr.getValue()); if (!ptrType.isOpaque() && ptrType.getElementType() != typeAttr.getValue()) return op->emitError() @@ -3033,7 +3033,7 @@ // Check to see if this function has a void return with a result attribute // to it. It isn't clear what semantics we would assign to that. - if (resType.isa()) + if (llvm::isa(resType)) return op->emitError() << "cannot attach result attributes to functions " "with a void return"; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp @@ -176,7 +176,7 @@ if (auto func = dyn_cast(parentOp)) { // Use the alignment attribute set for this argument in the parent function // if it has been set. - auto blockArg = value.cast(); + auto blockArg = llvm::cast(value); if (Attribute alignAttr = func.getArgAttr( blockArg.getArgNumber(), LLVM::LLVMDialect::getAlignAttrName())) return cast(alignAttr).getValue().getLimitedValue(); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp @@ -29,7 +29,7 @@ // names processed here (e.g. 'tbaa'). This verification // is redundant in some cases. if (!llvm::all_of(symbolRefs, [](Attribute attr) { - return attr && attr.isa(); + return attr && llvm::isa(attr); })) return op->emitOpError() << name << " attribute failed to satisfy constraint: " diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp @@ -24,7 +24,8 @@ /// internal functions to avoid getting a verbose `!llvm` prefix. Otherwise /// prints it as usual. static void dispatchPrint(AsmPrinter &printer, Type type) { - if (isCompatibleType(type) && !type.isa()) + if (isCompatibleType(type) && + !llvm::isa(type)) return mlir::LLVM::detail::printType(type, printer); printer.printType(type); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -130,8 +130,9 @@ //===----------------------------------------------------------------------===// bool LLVMArrayType::isValidElementType(Type type) { - return !type.isa(); + return !llvm::isa( + type); } LLVMArrayType LLVMArrayType::get(Type elementType, unsigned numElements) { @@ -186,11 +187,11 @@ //===----------------------------------------------------------------------===// bool LLVMFunctionType::isValidArgumentType(Type type) { - return !type.isa(); + return !llvm::isa(type); } bool LLVMFunctionType::isValidResultType(Type type) { - return !type.isa(); + return !llvm::isa(type); } LLVMFunctionType LLVMFunctionType::get(Type result, ArrayRef arguments, @@ -239,9 +240,9 @@ if (!type) return true; return isCompatibleOuterType(type) - ? !type.isa() - : type.isa(); + ? !llvm::isa(type) + : llvm::isa(type); } LLVMPointerType LLVMPointerType::get(Type pointee, unsigned addressSpace) { @@ -266,7 +267,7 @@ std::optional mlir::LLVM::extractPointerSpecValue(Attribute attr, PtrDLEntryPos pos) { - auto spec = attr.cast(); + auto spec = llvm::cast(attr); auto idx = static_cast(pos); if (idx >= spec.size()) return std::nullopt; @@ -285,8 +286,8 @@ for (DataLayoutEntryInterface entry : params) { if (!entry.isTypeEntry()) continue; - if (entry.getKey().get().cast().getAddressSpace() == - type.getAddressSpace()) { + if (llvm::cast(entry.getKey().get()) + .getAddressSpace() == type.getAddressSpace()) { currentEntry = entry.getValue(); break; } @@ -350,11 +351,11 @@ continue; unsigned size = kDefaultPointerSizeBits; unsigned abi = kDefaultPointerAlignment; - auto newType = newEntry.getKey().get().cast(); + auto newType = llvm::cast(newEntry.getKey().get()); const auto *it = llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) { if (auto type = entry.getKey().dyn_cast()) { - return type.cast().getAddressSpace() == + return llvm::cast(type).getAddressSpace() == newType.getAddressSpace(); } return false; @@ -362,7 +363,7 @@ if (it == oldLayout.end()) { llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) { if (auto type = entry.getKey().dyn_cast()) { - return type.cast().getAddressSpace() == 0; + return llvm::cast(type).getAddressSpace() == 0; } return false; }); @@ -372,7 +373,7 @@ abi = *extractPointerSpecValue(*it, PtrDLEntryPos::Abi); } - Attribute newSpec = newEntry.getValue().cast(); + Attribute newSpec = llvm::cast(newEntry.getValue()); unsigned newSize = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Size); unsigned newAbi = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Abi); if (size != newSize || abi < newAbi || abi % newAbi != 0) @@ -386,8 +387,8 @@ for (DataLayoutEntryInterface entry : entries) { if (!entry.isTypeEntry()) continue; - auto key = entry.getKey().get().cast(); - auto values = entry.getValue().dyn_cast(); + auto key = llvm::cast(entry.getKey().get()); + auto values = llvm::dyn_cast(entry.getValue()); if (!values || (values.size() != 3 && values.size() != 4)) { return emitError(loc) << "expected layout attribute for " << entry.getKey().get() @@ -412,8 +413,9 @@ //===----------------------------------------------------------------------===// bool LLVMStructType::isValidElementType(Type type) { - return !type.isa(); + return !llvm::isa( + type); } LLVMStructType LLVMStructType::getIdentified(MLIRContext *context, @@ -538,7 +540,7 @@ if (currentEntry == params.end()) return std::nullopt; - auto attr = currentEntry->getValue().cast(); + auto attr = llvm::cast(currentEntry->getValue()); if (pos == StructDLEntryPos::Preferred && attr.size() <= static_cast(StructDLEntryPos::Preferred)) // If no preferred was specified, fall back to abi alignment @@ -586,7 +588,7 @@ } static unsigned extractStructSpecValue(Attribute attr, StructDLEntryPos pos) { - return attr.cast() + return llvm::cast(attr) .getValues()[static_cast(pos)]; } @@ -619,8 +621,8 @@ if (!entry.isTypeEntry()) continue; - auto key = entry.getKey().get().cast(); - auto values = entry.getValue().dyn_cast(); + auto key = llvm::cast(entry.getKey().get()); + auto values = llvm::dyn_cast(entry.getValue()); if (!values || (values.size() != 2 && values.size() != 1)) { return emitError(loc) << "expected layout attribute for " << entry.getKey().get() @@ -676,7 +678,7 @@ } bool LLVMFixedVectorType::isValidElementType(Type type) { - return type.isa(); + return llvm::isa(type); } LogicalResult @@ -705,10 +707,11 @@ } bool LLVMScalableVectorType::isValidElementType(Type type) { - if (auto intType = type.dyn_cast()) + if (auto intType = llvm::dyn_cast(type)) return intType.isSignless(); - return isCompatibleFloatingPointType(type) || type.isa(); + return isCompatibleFloatingPointType(type) || + llvm::isa(type); } LogicalResult @@ -724,7 +727,7 @@ bool mlir::LLVM::isCompatibleOuterType(Type type) { // clang-format off - if (type.isa< + if (llvm::isa< BFloat16Type, Float16Type, Float32Type, @@ -745,17 +748,17 @@ LLVMScalableVectorType, LLVMVoidType, LLVMX86MMXType - >()) { + >(type)) { // clang-format on return true; } // Only signless integers are compatible. - if (auto intType = type.dyn_cast()) + if (auto intType = llvm::dyn_cast(type)) return intType.isSignless(); // 1D vector types are compatible. - if (auto vecType = type.dyn_cast()) + if (auto vecType = llvm::dyn_cast(type)) return vecType.getRank() == 1; return false; @@ -835,22 +838,22 @@ } bool mlir::LLVM::isCompatibleFloatingPointType(Type type) { - return type.isa(); + return llvm::isa(type); } bool mlir::LLVM::isCompatibleVectorType(Type type) { - if (type.isa()) + if (llvm::isa(type)) return true; - if (auto vecType = type.dyn_cast()) { + if (auto vecType = llvm::dyn_cast(type)) { if (vecType.getRank() != 1) return false; Type elementType = vecType.getElementType(); - if (auto intType = elementType.dyn_cast()) + if (auto intType = llvm::dyn_cast(elementType)) return intType.isSignless(); - return elementType.isa(); + return llvm::isa(elementType); } return false; } @@ -883,13 +886,12 @@ } bool mlir::LLVM::isScalableVectorType(Type vectorType) { - assert( - (vectorType - .isa()) && - "expected LLVM-compatible vector type"); - return !vectorType.isa() && - (vectorType.isa() || - vectorType.cast().isScalable()); + assert((llvm::isa( + vectorType)) && + "expected LLVM-compatible vector type"); + return !llvm::isa(vectorType) && + (llvm::isa(vectorType) || + llvm::cast(vectorType).isScalable()); } Type mlir::LLVM::getVectorType(Type elementType, unsigned numElements, @@ -970,9 +972,9 @@ elementSize.isScalable()); }) .Default([](Type ty) { - assert((ty.isa()) && + assert((llvm::isa(ty)) && "unexpected missing support for primitive type"); return llvm::TypeSize::Fixed(0); }); diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -90,13 +90,13 @@ return NVVM::MMATypes::f32; if (operandElType.isF32() && !isAccumulator) return NVVM::MMATypes::tf32; - if (operandElType.isa()) { + if (llvm::isa(operandElType)) { if (isAccumulator) return NVVM::MMATypes::s32; return std::nullopt; } - if (auto structType = operandElType.dyn_cast()) { + if (auto structType = llvm::dyn_cast(operandElType)) { if (structType.getBody().empty()) return std::nullopt; return inferOperandMMAType(structType.getBody()[0], isAccumulator); @@ -526,9 +526,9 @@ LogicalResult ShflOp::verify() { if (!(*this)->getAttrOfType("return_value_and_is_valid")) return success(); - auto type = getType().dyn_cast(); + auto type = llvm::dyn_cast(getType()); auto elementType = (type && type.getBody().size() == 2) - ? type.getBody()[1].dyn_cast() + ? llvm::dyn_cast(type.getBody()[1]) : nullptr; if (!elementType || elementType.getWidth() != 1) return emitError("expected return type to be a two-element struct with " @@ -600,7 +600,7 @@ LogicalResult NVVM::WMMALoadOp::verify() { unsigned addressSpace = - getPtr().getType().cast().getAddressSpace(); + llvm::cast(getPtr().getType()).getAddressSpace(); if (addressSpace != 0 && addressSpace != 1 && addressSpace != 3) return emitOpError("expected source pointer in memory " "space 0, 1, 3"); @@ -620,7 +620,7 @@ LogicalResult NVVM::WMMAStoreOp::verify() { unsigned addressSpace = - getPtr().getType().cast().getAddressSpace(); + llvm::cast(getPtr().getType()).getAddressSpace(); if (addressSpace != 0 && addressSpace != 1 && addressSpace != 3) return emitOpError("expected operands to be a source pointer in memory " "space 0, 1, 3"); @@ -672,7 +672,7 @@ LogicalResult NVVM::LdMatrixOp::verify() { unsigned addressSpace = - getPtr().getType().cast().getAddressSpace(); + llvm::cast(getPtr().getType()).getAddressSpace(); if (addressSpace != 3) return emitOpError("expected source pointer in memory space 3"); @@ -725,13 +725,13 @@ // If maxntid and reqntid exist, it must be an array with max 3 dim if (attrName == NVVMDialect::getMaxntidAttrName() || attrName == NVVMDialect::getReqntidAttrName()) { - auto values = attr.getValue().dyn_cast(); + auto values = llvm::dyn_cast(attr.getValue()); if (!values || values.empty() || values.size() > 3) return op->emitError() << "'" << attrName << "' attribute must be integer array with maximum 3 index"; - for (auto val : attr.getValue().cast()) { - if (!val.dyn_cast()) + for (auto val : llvm::cast(attr.getValue())) { + if (!llvm::dyn_cast(val)) return op->emitError() << "'" << attrName << "' attribute must be integer array with maximum 3 index"; @@ -740,7 +740,7 @@ // If minctasm and maxnreg exist, it must be an array with max 3 dim if (attrName == NVVMDialect::getMinctasmAttrName() || attrName == NVVMDialect::getMaxnregAttrName()) { - if (!attr.getValue().dyn_cast()) + if (!llvm::dyn_cast(attr.getValue())) return op->emitError() << "'" << attrName << "' attribute must be integer constant"; } diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp --- a/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp +++ b/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp @@ -25,11 +25,11 @@ /// Attempt to extract a filename for the given loc. static FileLineColLoc extractFileLoc(Location loc) { - if (auto fileLoc = loc.dyn_cast()) + if (auto fileLoc = llvm::dyn_cast(loc)) return fileLoc; - if (auto nameLoc = loc.dyn_cast()) + if (auto nameLoc = llvm::dyn_cast(loc)) return extractFileLoc(nameLoc.getChildLoc()); - if (auto opaqueLoc = loc.dyn_cast()) + if (auto opaqueLoc = llvm::dyn_cast(loc)) return extractFileLoc(opaqueLoc.getFallbackLocation()); return FileLineColLoc(); } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -506,15 +506,15 @@ /// the type of `source`. static Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim) { - if (source.getType().isa()) + if (llvm::isa(source.getType())) return b.createOrFold(loc, source, dim); - if (source.getType().isa()) + if (llvm::isa(source.getType())) return b.createOrFold(loc, source, dim); llvm_unreachable("Expected MemRefType or TensorType"); } static OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value source, int64_t dim) { - auto shapedType = source.getType().cast(); + auto shapedType = llvm::cast(source.getType()); if (!shapedType.hasRank() || shapedType.isDynamicDim(dim)) return createOrFoldDimOp(b, loc, source, dim); return b.getIndexAttr(shapedType.getDimSize(dim)); @@ -644,7 +644,7 @@ for (OpOperand *opOperand : getDpsInitOperands()) { SmallVector shapes; for (int64_t dim : llvm::seq(0, getRank(opOperand))) { - auto shapedType = opOperand->get().getType().cast(); + auto shapedType = llvm::cast(opOperand->get().getType()); if (!shapedType.isDynamicDim(dim)) { // Static dim: Return IntegerAttr. shapes.push_back(b.getIndexAttr(shapedType.getDimSize(dim))); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -63,7 +63,8 @@ TypeRange inputTypes, TypeRange outputTypes, ArrayRef attrs, RegionBuilderFn regionBuilder) { - assert(llvm::all_of(outputTypes, [](Type t) { return t.isa(); })); + assert(llvm::all_of(outputTypes, + [](Type t) { return llvm::isa(t); })); // TODO: atm all operands go through getElementTypeOrSelf, // reconsider when we have evidence we need to. @@ -106,7 +107,7 @@ resultTensorTypes.value_or(TypeRange()); if (!resultTensorTypes) copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes), - [](Type type) { return type.isa(); }); + [](Type type) { return llvm::isa(type); }); state.addOperands(inputs); state.addOperands(outputs); @@ -173,7 +174,7 @@ // Otherwise we append it to the discardable attributes dictionary where it is // handled by the generic Operation::create(...) method. if (result.propertiesAttr) { - NamedAttrList attrs = result.propertiesAttr.cast(); + NamedAttrList attrs = llvm::cast(result.propertiesAttr); attrs.append("operand_segment_sizes", parser.getBuilder().getDenseI32ArrayAttr( {static_cast(inputsOperands.size()), @@ -448,9 +449,15 @@ return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast); } - bool isComplex(Value value) { return value.getType().isa(); } - bool isFloatingPoint(Value value) { return value.getType().isa(); } - bool isInteger(Value value) { return value.getType().isa(); } + bool isComplex(Value value) { + return llvm::isa(value.getType()); + } + bool isFloatingPoint(Value value) { + return llvm::isa(value.getType()); + } + bool isInteger(Value value) { + return llvm::isa(value.getType()); + } OpBuilder getBuilder() { OpBuilder builder(context); @@ -748,8 +755,7 @@ for (auto attr : (*this)->getAttrs()) { if (attr.getName() == getIteratorTypesAttrName()) { auto iteratorTypes = - attr.getValue() - .cast() + llvm::cast(attr.getValue()) .getAsValueRange(); // Convert IteratorType enums into the string representation. This is // needed, because tests still use the old format when 'iterator_types' @@ -873,13 +879,13 @@ ValueRange results, const OpOperandVector &inputOperands, const OpOperandVector &outputOperands) { for (auto *operand : inputOperands) { - if (!operand->get().getType().isa()) + if (!llvm::isa(operand->get().getType())) continue; effects.emplace_back(MemoryEffects::Read::get(), operand->get(), SideEffects::DefaultResource::get()); } for (auto *operand : outputOperands) { - if (!operand->get().getType().isa()) + if (!llvm::isa(operand->get().getType())) continue; effects.emplace_back(MemoryEffects::Read::get(), operand->get(), SideEffects::DefaultResource::get()); @@ -942,7 +948,7 @@ // number to use for replacing uses of this operation. SmallVector returnedArgs; for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) { - auto yieldArg = yieldVal.value().dyn_cast(); + auto yieldArg = llvm::dyn_cast(yieldVal.value()); if (!yieldArg || yieldArg.getOwner() != &body) return failure(); unsigned argumentNumber = yieldArg.getArgNumber(); @@ -1003,7 +1009,7 @@ // Add result types. for (Type outputType : outputTypes) { - if (outputType.isa()) + if (llvm::isa(outputType)) result.addTypes(outputType); } @@ -1037,7 +1043,7 @@ // Add output types for `RankedTensorType` output arguments. Type initType = init.getType(); - if (initType.isa()) + if (llvm::isa(initType)) result.addTypes(initType); if (bodyBuild) @@ -1056,8 +1062,9 @@ b.setInsertionPointToStart(&block); SmallVector bbArgs; for (auto &operand : operands) { - block.addArgument(operand.getType().cast().getElementType(), - b.getUnknownLoc()); + block.addArgument( + llvm::cast(operand.getType()).getElementType(), + b.getUnknownLoc()); } SmallVector payloadOpOperands; // If initFirst flag is enabled, we consider init as the first position of @@ -1074,8 +1081,8 @@ Operation *payloadOp = b.create( result.location, b.getStringAttr(payloadOpName.getStringRef()), payloadOpOperands, - TypeRange{ - result.operands.back().getType().cast().getElementType()}, + TypeRange{llvm::cast(result.operands.back().getType()) + .getElementType()}, payloadOpAttrs); b.create(result.location, payloadOp->getResults()); } @@ -1151,7 +1158,8 @@ std::string attrToElide; p << " { " << payloadOp->getName().getStringRef(); for (const auto &attr : payloadOp->getAttrs()) { - auto fastAttr = attr.getValue().dyn_cast(); + auto fastAttr = + llvm::dyn_cast(attr.getValue()); if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) { attrToElide = attr.getName().str(); elidedAttrs.push_back(attrToElide); @@ -1200,7 +1208,8 @@ // The parameters of mapper should all match the element type of inputs. for (const auto &[bbArgType, inputArg] : llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) { - auto inputElemType = inputArg.getType().cast().getElementType(); + auto inputElemType = + llvm::cast(inputArg.getType()).getElementType(); if (bbArgType != inputElemType) { return emitOpError() << "expected element type of input " << inputElemType << " to match bbArg type " << bbArgType; @@ -1210,7 +1219,7 @@ // The shape of each input must match the shape of the output. auto outputShape = getInit().getType().getShape(); for (Type inputArgType : TypeRange{getInputs()}) { - auto inputElemShape = inputArgType.cast().getShape(); + auto inputElemShape = llvm::cast(inputArgType).getShape(); if (inputElemShape != outputShape) { return emitOpError() << "expected shape of input (" << inputElemShape << ") to match shape of output (" << outputShape @@ -1270,7 +1279,7 @@ // Add output types for `RankedTensorType` output arguments. for (Value init : inits) { Type initType = init.getType(); - if (initType.isa()) + if (llvm::isa(initType)) result.addTypes(initType); } @@ -1280,7 +1289,8 @@ } SmallVector ReduceOp::getIteratorTypesArray() { - int64_t inputRank = getInputs()[0].getType().cast().getRank(); + int64_t inputRank = + llvm::cast(getInputs()[0].getType()).getRank(); SmallVector iteratorTypes(inputRank, utils::IteratorType::parallel); for (int64_t reductionDim : getDimensions()) @@ -1289,7 +1299,8 @@ } ArrayAttr ReduceOp::getIndexingMaps() { - int64_t inputRank = getInputs()[0].getType().cast().getRank(); + int64_t inputRank = + llvm::cast(getInputs()[0].getType()).getRank(); SmallVector affineMaps( getNumDpsInputs(), AffineMap::getMultiDimIdentityMap(inputRank, getContext())); @@ -1390,8 +1401,8 @@ ArrayRef dimensionsRef = getDimensions(); for (int64_t i = 1; i < getNumDpsInputs(); ++i) { - if (getInputs()[i].getType().cast().getShape() != - getInputs()[0].getType().cast().getShape()) { + if (llvm::cast(getInputs()[i].getType()).getShape() != + llvm::cast(getInputs()[0].getType()).getShape()) { return emitOpError() << "expects all inputs to have the same shapes. " "Shape at input-index " << i @@ -1399,16 +1410,16 @@ } } for (int64_t i = 1; i < getNumDpsInits(); ++i) { - if (getInits()[i].getType().cast().getShape() != - getInits()[0].getType().cast().getShape()) { + if (llvm::cast(getInits()[i].getType()).getShape() != + llvm::cast(getInits()[0].getType()).getShape()) { return emitOpError() << "expects all outputs to have the same shapes. " "Shape at output-index " << i << " is not equal to the shape at output-index 0."; } } - auto inputType = getInputs()[0].getType().cast(); - auto initType = getInits()[0].getType().cast(); + auto inputType = llvm::cast(getInputs()[0].getType()); + auto initType = llvm::cast(getInits()[0].getType()); DenseSet dimensionsToReduce; for (int64_t dimension : dimensionsRef) { @@ -1449,7 +1460,8 @@ // Check that the first block arguments match the element type of the inputs. for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) { - Type inputElementType = input.getType().cast().getElementType(); + Type inputElementType = + llvm::cast(input.getType()).getElementType(); if (inputElementType != bbArg.getType()) return emitOpError() << "input element type " << inputElementType @@ -1462,7 +1474,7 @@ llvm::zip(getDpsInitOperands(), block->getArguments().take_back(getNumDpsInits()))) { auto outputElementType = - output->get().getType().cast().getElementType(); + llvm::cast(output->get().getType()).getElementType(); if (outputElementType != bbArg.getType()) return emitOpError() << "output element type " << outputElementType @@ -1496,7 +1508,7 @@ // Add output types for `RankedTensorType` output arguments. Type initType = init.getType(); - if (initType.isa()) + if (llvm::isa(initType)) result.addTypes(initType); buildIdentityRegion(builder, result.location, *result.addRegion(), input, @@ -1610,7 +1622,7 @@ // Add output types for `RankedTensorType` output arguments. Type initType = init.getType(); - if (initType.isa()) + if (llvm::isa(initType)) result.addTypes(initType); buildIdentityRegion(builder, result.location, *result.addRegion(), input, @@ -1828,7 +1840,7 @@ } static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) { - if (auto memref = t.dyn_cast()) { + if (auto memref = llvm::dyn_cast(t)) { ss << "view"; for (auto size : memref.getShape()) if (size < 0) @@ -1838,14 +1850,14 @@ if (failed(appendMangledType(ss, memref.getElementType()))) return failure(); if (auto as = memref.getMemorySpace()) { - if (auto attr = as.dyn_cast()) + if (auto attr = llvm::dyn_cast(as)) ss << "as" << attr.getInt(); else return failure(); } return success(); } - if (auto vec = t.dyn_cast()) { + if (auto vec = llvm::dyn_cast(t)) { ss << "vector"; llvm::interleave( vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; }); @@ -1864,9 +1876,9 @@ std::string name(op->getName().getStringRef().str()); std::string fun = ""; for (NamedAttribute kv : op->getAttrs()) { - if (UnaryFnAttr ufa = kv.getValue().dyn_cast()) { + if (UnaryFnAttr ufa = llvm::dyn_cast(kv.getValue())) { fun = stringifyEnum(ufa.getValue()).str() + "_"; - } else if (BinaryFnAttr bfa = kv.getValue().dyn_cast()) { + } else if (BinaryFnAttr bfa = llvm::dyn_cast(kv.getValue())) { fun = stringifyEnum(bfa.getValue()).str() + "_"; } } @@ -1898,7 +1910,7 @@ // Linalg "inputs" may be either tensor or memref type. // tensor<0xelt_type> is a convention that may not always mean // "0 iterations". Only erase in cases we see memref<...x0x...>. - auto mt = opOperand.get().getType().dyn_cast(); + auto mt = llvm::dyn_cast(opOperand.get().getType()); if (!mt) continue; if (llvm::is_contained(op.getShape(&opOperand), 0)) { @@ -1934,9 +1946,10 @@ rewriter.setInsertionPoint(linalgOp); Location loc = linalgOp.getLoc(); - OpResult resultValue = castOp.getSource().cast(); + OpResult resultValue = llvm::cast(castOp.getSource()); unsigned resultNumber = resultValue.getResultNumber(); - auto resultType = castOp->getResult(0).getType().cast(); + auto resultType = + llvm::cast(castOp->getResult(0).getType()); // Replace the `outs` for the result with a `tensor.cast`. This cast is now // going from a more dynamic shape to a less dynamic shape. If the producer // for this cast, i.e. producer of the out operand, is also an operation @@ -1975,7 +1988,7 @@ if (linalgOp.isScalar(&opOperand)) continue; Value src = opOperand.get(); - auto sourceType = src.getType().cast(); + auto sourceType = llvm::cast(src.getType()); auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand); // Get the `sourceShape` of the `sourceType`. If the operand is a result of @@ -1986,7 +1999,8 @@ if (parentOp) { if (auto castOp = dyn_cast(parentOp)) { Value castSource = castOp.getSource(); - auto castSourceType = castSource.getType().dyn_cast(); + auto castSourceType = + llvm::dyn_cast(castSource.getType()); if (castSourceType && castSourceType.hasStaticShape()) sourceShape = castSourceType.getShape(); } @@ -2017,7 +2031,7 @@ newOperands.push_back(src); if (linalgOp.isScalar(opOperand)) return; - auto sourceType = src.getType().cast(); + auto sourceType = llvm::cast(src.getType()); Type resultType = sourceType; if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) { resultTypes.push_back(resultType); diff --git a/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp @@ -37,7 +37,7 @@ int64_t flatDimCtr = 0; for (Value operand : linalgOp->getOperands()) { assert(flatDimPos >= flatDimCtr && "invalid pos"); - auto shapedType = operand.getType().cast(); + auto shapedType = llvm::cast(operand.getType()); if (flatDimPos < flatDimCtr + shapedType.getRank()) { cstr.bound(value) < cstr.getExpr(operand, flatDimPos - flatDimCtr); break; diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -87,7 +87,7 @@ SmallVector &result, ArrayRef ofrs) { for (OpFoldResult ofr : ofrs) { if (ofr.is()) { - if (!ofr.get().isa()) + if (!llvm::isa(ofr.get())) return transformOp.emitDefiniteFailure() << "expected IntegerAttr"; result.push_back(ofr); continue; @@ -155,7 +155,7 @@ llvm::map_range(state.getPayloadValues(getTarget()), [&](Value v) { return linalg::bufferizeToAllocation(rewriter, v, memorySpace); })); - results.setValues(getTransformed().cast(), transformed); + results.setValues(llvm::cast(getTransformed()), transformed); return DiagnosedSilenceableFailure::success(); } @@ -276,7 +276,7 @@ if (!sizesAttr) return parser.emitError(opLoc) << "expected '" << sizesAttrName << "' attribute"; - auto sizesArrayAttr = sizesAttr.dyn_cast(); + auto sizesArrayAttr = llvm::dyn_cast(sizesAttr); if (!sizesArrayAttr) return parser.emitError(opLoc) << "'" << sizesAttrName << "' attribute must be an array"; @@ -389,7 +389,7 @@ // Tile the producer. int64_t resultNumber = - sliceOpToTile.getSource().cast().getResultNumber(); + llvm::cast(sliceOpToTile.getSource()).getResultNumber(); LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n"); FailureOr tileAndFuseResult = @@ -411,9 +411,7 @@ // Replace the extract op. auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded( rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0], - sliceOpToTile->getResult(0) - .getType() - .cast() + llvm::cast(sliceOpToTile->getResult(0).getType()) .getShape()); assert(succeeded(maybeRankReduced) && "unexpected shape"); rewriter.replaceOp(sliceOpToTile, *maybeRankReduced); @@ -482,7 +480,7 @@ // Replace the use in the tileableProducer before tiling: clone, replace and // then tile. - int64_t resultNumber = pUse->get().cast().getResultNumber(); + int64_t resultNumber = llvm::cast(pUse->get()).getResultNumber(); LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n"); // Gather destination tensors. @@ -516,9 +514,7 @@ // Replace the extract op. auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded( rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0], - sliceOpToTile->getResult(0) - .getType() - .cast() + llvm::cast(sliceOpToTile->getResult(0).getType()) .getShape()); assert(succeeded(maybeRankReduced) && "unexpected shape"); rewriter.replaceOp(sliceOpToTile, *maybeRankReduced); @@ -568,7 +564,7 @@ // TODO: Generalize to other type of ops. assert(!isa(use->getOwner()) && "Parallel insert slice is not a valid clone destination"); - unsigned resultNumber = use->get().cast().getResultNumber(); + unsigned resultNumber = llvm::cast(use->get()).getResultNumber(); LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n"); OpBuilder::InsertionGuard guard(rewriter); @@ -587,7 +583,7 @@ ArrayRef producerOps = state.getPayloadOps(getProducerOp()); // If nothing to fuse, propagate success. if (producerOps.empty()) { - results.set(getFusedOp().cast(), + results.set(llvm::cast(getFusedOp()), SmallVector{}); return DiagnosedSilenceableFailure::success(); } @@ -671,7 +667,7 @@ return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); } - results.set(getFusedOp().cast(), fusedOps); + results.set(llvm::cast(getFusedOp()), fusedOps); return DiagnosedSilenceableFailure::success(); } @@ -865,7 +861,7 @@ }; payloadOps.front()->walk(matchFun); - results.set(getResult().cast(), res); + results.set(llvm::cast(getResult()), res); return DiagnosedSilenceableFailure::success(); } @@ -901,7 +897,7 @@ DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne( LinalgOp target, transform::ApplyToEachResultList &results, TransformState &state) { - if (getLowSize().getType().isa()) { + if (llvm::isa(getLowSize().getType())) { if (target.hasDynamicShape()) { auto diag = emitSilenceableError() << "cannot compute parametric tile sizes for dynamically " @@ -923,7 +919,7 @@ spec->lowTileSize * spec->lowTripCount}), [&builder, this](int64_t value) { return builder.getIntegerAttr( - getLowSize().getType().cast().getType(), value); + llvm::cast(getLowSize().getType()).getType(), value); })); return DiagnosedSilenceableFailure::success(); } @@ -958,7 +954,7 @@ SmallVectorImpl &effects) { onlyReadsHandle(getTarget(), effects); producesHandle(getResults(), effects); - if (getLowSize().getType().isa()) + if (llvm::isa(getLowSize().getType())) onlyReadsPayload(effects); else modifiesPayload(effects); @@ -1006,7 +1002,7 @@ ArrayRef targetOps = state.getPayloadOps(getTarget()); // If nothing to pack, propagate success. if (targetOps.empty()) { - transformResults.set(getPackedOp().cast(), {}); + transformResults.set(llvm::cast(getPackedOp()), {}); return DiagnosedSilenceableFailure::success(); } // Fail on multi-op handles. @@ -1036,7 +1032,7 @@ if (failed(maybeResult)) return emitDefiniteFailure("data tiling failed"); - transformResults.set(getPackedOp().cast(), + transformResults.set(llvm::cast(getPackedOp()), maybeResult->packedLinalgOp.getOperation()); return DiagnosedSilenceableFailure::success(); } @@ -1242,7 +1238,7 @@ } results.push_back(linalgOp); } - transformResults.set(getPackedOp().cast(), results); + transformResults.set(llvm::cast(getPackedOp()), results); return DiagnosedSilenceableFailure::success(); } @@ -1322,9 +1318,9 @@ ArrayRef linalgOps = state.getPayloadOps(getTargetLinalgOp()); // Step 1. If nothing to pack, propagate success. if (packOrUnpackOps.empty()) { - transformResults.set(getPackedOp().cast(), {}); - transformResults.set(getPackOp().cast(), {}); - transformResults.set(getUnPackOp().cast(), {}); + transformResults.set(llvm::cast(getPackedOp()), {}); + transformResults.set(llvm::cast(getPackOp()), {}); + transformResults.set(llvm::cast(getUnPackOp()), {}); return DiagnosedSilenceableFailure::success(); } @@ -1366,7 +1362,7 @@ if (unPackOp) { assert(!packOp && "packOp must be null on entry when unPackOp is not null"); OpOperand *packUse = linalgOp.getDpsInitOperand( - unPackOp.getSource().cast().getResultNumber()); + llvm::cast(unPackOp.getSource()).getResultNumber()); packOp = dyn_cast_or_null(packUse->get().getDefiningOp()); if (!packOp || !packOp.getResult().hasOneUse()) return emitSilenceableError() << "could not find matching pack op"; @@ -1400,14 +1396,15 @@ assert(succeeded(res) && "unexpected packTranspose failure"); // Step 4. Return results. - transformResults.set(getPackOp().cast(), {res->transposedPackOp}); - transformResults.set(getPackedOp().cast(), + transformResults.set(llvm::cast(getPackOp()), + {res->transposedPackOp}); + transformResults.set(llvm::cast(getPackedOp()), {res->transposedLinalgOp}); if (unPackOp) { - transformResults.set(getUnPackOp().cast(), + transformResults.set(llvm::cast(getUnPackOp()), {res->transposedUnPackOp}); } else { - transformResults.set(getUnPackOp().cast(), {}); + transformResults.set(llvm::cast(getUnPackOp()), {}); } return DiagnosedSilenceableFailure::success(); @@ -1430,14 +1427,14 @@ SmallVector paddingValues; for (auto const &it : llvm::zip(getPaddingValues(), target->getOperandTypes())) { - auto attr = std::get<0>(it).dyn_cast(); + auto attr = llvm::dyn_cast(std::get<0>(it)); if (!attr) { emitOpError("expects padding values to be typed attributes"); return DiagnosedSilenceableFailure::definiteFailure(); } Type elementType = getElementTypeOrSelf(std::get<1>(it)); // Try to parse string attributes to obtain an attribute of element type. - if (auto stringAttr = attr.dyn_cast()) { + if (auto stringAttr = llvm::dyn_cast(attr)) { auto parsedAttr = dyn_cast_if_present( parseAttribute(stringAttr, getContext(), elementType, /*numRead=*/nullptr, /*isKnownNullTerminated=*/true)); @@ -1462,9 +1459,10 @@ // Extract the transpose vectors. SmallVector> transposePaddings; - for (Attribute transposeVector : getTransposePaddings().cast()) + for (Attribute transposeVector : + llvm::cast(getTransposePaddings())) transposePaddings.push_back( - extractFromI64ArrayAttr(transposeVector.cast())); + extractFromI64ArrayAttr(llvm::cast(transposeVector))); TrackingListener listener(state, *this); IRRewriter rewriter(getContext(), &listener); @@ -1549,13 +1547,13 @@ return emitDefiniteFailure() << "could not build packing loop nest"; if (result->clonedLoopIvs.empty()) { - transformResults.set(getPackingLoop().cast(), + transformResults.set(llvm::cast(getPackingLoop()), result->hoistedPadOp.getOperation()); return DiagnosedSilenceableFailure::success(); } auto outerPackedLoop = scf::getForInductionVarOwner(result->clonedLoopIvs.front()); - transformResults.set(getPackingLoop().cast(), + transformResults.set(llvm::cast(getPackingLoop()), outerPackedLoop.getOperation()); return DiagnosedSilenceableFailure::success(); } @@ -1643,7 +1641,7 @@ if (mapping.size() > 1) return emitDefaultDefiniteFailure(target); - auto addressSpace = mapping[0].cast(); + auto addressSpace = llvm::cast(mapping[0]); if (addressSpace.getAddressSpace() == gpu::GPUDialect::getWorkgroupAddressSpace()) { @@ -1711,7 +1709,7 @@ rewriter.replaceOp(target, replacement->getResults()); replacements.push_back(replacement); } - transformResults.set(getReplacement().cast(), replacements); + transformResults.set(llvm::cast(getReplacement()), replacements); return DiagnosedSilenceableFailure::success(); } @@ -1828,7 +1826,8 @@ splitPoints.reserve(payload.size()); if (getDynamicSplitPoint()) { auto diag = DiagnosedSilenceableFailure::success(); - if (getDynamicSplitPoint().getType().isa()) { + if (llvm::isa( + getDynamicSplitPoint().getType())) { splitPoints = llvm::to_vector(llvm::map_range( state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) { if (op->getNumResults() != 1 || @@ -1909,8 +1908,8 @@ return diag; } - results.set(getFirst().cast(), first); - results.set(getSecond().cast(), second); + results.set(llvm::cast(getFirst()), first); + results.set(llvm::cast(getSecond()), second); return DiagnosedSilenceableFailure::success(); } @@ -2212,12 +2211,12 @@ dynamicSizeProducers.reserve(getDynamicSizes().size()); paramSizes.reserve(getDynamicSizes().size()); for (Value transformValue : getDynamicSizes()) { - if (transformValue.getType().isa()) { + if (llvm::isa(transformValue.getType())) { dynamicSizeProducers.push_back({}); ArrayRef params = state.getParams(transformValue); paramSizes.push_back( llvm::to_vector(llvm::map_range(params, [](Attribute attr) { - return attr.cast().getValue().getSExtValue(); + return llvm::cast(attr).getValue().getSExtValue(); }))); if (paramSizes.back().size() != targets.size()) { @@ -2247,7 +2246,7 @@ for (Operation *op : dynamicSizeProducers.back()) { if (op->getNumResults() == 1 && - op->getResult(0).getType().isa()) + llvm::isa(op->getResult(0).getType())) continue; DiagnosedSilenceableFailure diag = @@ -2283,7 +2282,7 @@ for (OpFoldResult ofr : getMixedSizes()) { if (auto attr = ofr.dyn_cast()) { sizes.push_back(b.create( - getLoc(), attr.cast().getInt())); + getLoc(), llvm::cast(attr).getInt())); continue; } ArrayRef dynamicSizes = dynamicSizeProducers[dynamicIdx]; @@ -2320,9 +2319,10 @@ loops[en2.index()].push_back(en2.value()); } - transformResults.set(getTiledLinalgOp().cast(), tiled); + transformResults.set(llvm::cast(getTiledLinalgOp()), tiled); for (const auto &en : llvm::enumerate(loops)) - transformResults.set(getLoops()[en.index()].cast(), en.value()); + transformResults.set(llvm::cast(getLoops()[en.index()]), + en.value()); return DiagnosedSilenceableFailure::success(); } @@ -2582,8 +2582,8 @@ tiledOps.push_back(tilingResult.tiledOp); } - transformResults.set(getForallOp().cast(), tileOps); - transformResults.set(getTiledOp().cast(), tiledOps); + transformResults.set(llvm::cast(getForallOp()), tileOps); + transformResults.set(llvm::cast(getTiledOp()), tiledOps); return DiagnosedSilenceableFailure::success(); } @@ -2678,7 +2678,7 @@ for (Operation *op : dynamicSizeProducers.back()) { if (op->getNumResults() == 1 && - op->getResult(0).getType().isa()) + llvm::isa(op->getResult(0).getType())) continue; DiagnosedSilenceableFailure diag = emitSilenceableError() << "expected sizes to be produced by ops " @@ -2712,7 +2712,7 @@ for (OpFoldResult ofr : getMixedSizes()) { if (auto attr = ofr.dyn_cast()) { sizes.push_back(b.create( - getLoc(), attr.cast().getInt())); + getLoc(), llvm::cast(attr).getInt())); } else { sizes.push_back( dynamicSizeProducers[dynamicIdx++][index]->getResult(0)); @@ -2737,9 +2737,10 @@ loops[en2.index()].push_back(en2.value()); } - transformResults.set(getTiledLinalgOp().cast(), tiled); + transformResults.set(llvm::cast(getTiledLinalgOp()), tiled); for (const auto &en : llvm::enumerate(loops)) - transformResults.set(getLoops()[en.index()].cast(), en.value()); + transformResults.set(llvm::cast(getLoops()[en.index()]), + en.value()); return DiagnosedSilenceableFailure::success(); } @@ -2899,7 +2900,7 @@ for (OpFoldResult sz : getMixedVectorSizes()) { if (sz.is()) { auto attr = sz.get(); - vectorSizes.push_back(attr.cast().getInt()); + vectorSizes.push_back(llvm::cast(attr).getInt()); continue; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp @@ -64,20 +64,21 @@ if (genericOp.getNumDpsInits() != 1) return failure(); - auto outputType = genericOp.getResultTypes().front().dyn_cast(); + auto outputType = + llvm::dyn_cast(genericOp.getResultTypes().front()); // Require the output types to be static given that we are generating // constants. if (!outputType || !outputType.hasStaticShape()) return failure(); if (!llvm::all_of(genericOp.getInputs(), [](Value input) { - return input.getType().isa(); + return llvm::isa(input.getType()); })) return failure(); // Make sure all element types are the same. auto getOperandElementType = [](Value value) { - return value.getType().cast().getElementType(); + return llvm::cast(value.getType()).getElementType(); }; if (!llvm::all_equal( llvm::map_range(genericOp->getOperands(), getOperandElementType))) @@ -138,7 +139,7 @@ // unify the following cases but they have lifetime as the MLIRContext. SmallVector intOutputValues; SmallVector fpOutputValues; - if (elementType.template isa()) + if (llvm::isa(elementType)) fpOutputValues.resize(numElements, APFloat(0.f)); else intOutputValues.resize(numElements); @@ -174,7 +175,7 @@ auto inputShapes = llvm::to_vector<4>( llvm::map_range(genericOp.getInputs(), [](Value value) { - return value.getType().cast().getShape(); + return llvm::cast(value.getType()).getShape(); })); // Given a `linearIndex`, remap it to a linear index to access linalg op @@ -205,7 +206,7 @@ } }; - bool isFloat = elementType.isa(); + bool isFloat = llvm::isa(elementType); if (isFloat) { SmallVector> inFpRanges; for (int i = 0; i < numInputs; ++i) @@ -282,7 +283,7 @@ // The yield op should return the block argument corresponds to the input. for (Value yieldVal : yieldOp.getValues()) { - auto yieldArg = yieldVal.dyn_cast(); + auto yieldArg = llvm::dyn_cast(yieldVal); if (!yieldArg || yieldArg.getOwner() != &body) return nullptr; if (yieldArg.getArgNumber() != 0) diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp @@ -29,7 +29,7 @@ } static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) { - bool isInt = x.getType().isa(); + bool isInt = llvm::isa(x.getType()); if (isInt) return builder.create(loc, x, y); return builder.create(loc, x, y); @@ -42,7 +42,7 @@ convertScalarToDtype(builder, loc, x, accType, /*isUnsignedCast=*/false); Value yConvert = convertScalarToDtype(builder, loc, y, accType, /*isUnsignedCast=*/false); - if (accType.isa()) + if (llvm::isa(accType)) return builder.create(loc, xConvert, yConvert); return builder.create(loc, xConvert, yConvert); } @@ -74,9 +74,9 @@ FailureOr> rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) { - auto inputType = convOp.getInputs()[0].getType().cast(); - auto filterType = convOp.getInputs()[1].getType().cast(); - auto outputType = convOp.getOutputs()[0].getType().cast(); + auto inputType = llvm::cast(convOp.getInputs()[0].getType()); + auto filterType = llvm::cast(convOp.getInputs()[1].getType()); + auto outputType = llvm::cast(convOp.getOutputs()[0].getType()); if (!filterType.hasStaticShape()) return rewriter.notifyMatchFailure( @@ -210,9 +210,12 @@ FailureOr> rewriteInIm2Col(RewriterBase &rewriter, linalg::DepthwiseConv2DNhwcHwcOp convOp) { - auto inputType = convOp.getInputs()[0].getType().cast(); - auto filterType = convOp.getInputs()[1].getType().cast(); - auto outputType = convOp.getOutputs()[0].getType().cast(); + auto inputType = + llvm::cast(convOp.getInputs()[0].getType()); + auto filterType = + llvm::cast(convOp.getInputs()[1].getType()); + auto outputType = + llvm::cast(convOp.getOutputs()[0].getType()); if (!filterType.hasStaticShape()) return rewriter.notifyMatchFailure( @@ -230,7 +233,7 @@ Location loc = convOp.getLoc(); auto transposeOperand = [&](Value operand, ArrayRef indices) { - auto operandTensorType = operand.getType().cast(); + auto operandTensorType = llvm::cast(operand.getType()); auto nloops = indices.size(); ArrayRef inputShape = operandTensorType.getShape(); @@ -272,7 +275,7 @@ Value inputT = transposeOperand(input, {0, 3, 1, 2}); Value filterT = transposeOperand(filter, {2, 0, 1}); ArrayRef filterTShape = - filterT.getType().cast().getShape(); + llvm::cast(filterT.getType()).getShape(); ArrayRef outputShape = outputType.getShape(); int n = outputShape[0]; @@ -360,9 +363,9 @@ FailureOr> rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) { - auto inputType = convOp.getInputs()[0].getType().cast(); - auto filterType = convOp.getInputs()[1].getType().cast(); - auto outputType = convOp.getOutputs()[0].getType().cast(); + auto inputType = llvm::cast(convOp.getInputs()[0].getType()); + auto filterType = llvm::cast(convOp.getInputs()[1].getType()); + auto outputType = llvm::cast(convOp.getOutputs()[0].getType()); if (!filterType.hasStaticShape()) return rewriter.notifyMatchFailure( diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp @@ -66,12 +66,12 @@ Attribute constYieldedValue; // Is the yielded value a bbArg defined outside of the PadOp? bool outsideBbArg = - yieldedValue.isa() && - yieldedValue.cast().getOwner()->getParentOp() != + llvm::isa(yieldedValue) && + llvm::cast(yieldedValue).getOwner()->getParentOp() != padOp.getOperation(); // Is the yielded value an OpResult defined outside of the PadOp? bool outsideOpResult = - yieldedValue.isa() && + llvm::isa(yieldedValue) && yieldedValue.getDefiningOp()->getParentOp() != padOp.getOperation(); bool invariantYieldedValue = outsideBbArg || outsideOpResult; if (matchPattern(yieldedValue, m_Constant(&constYieldedValue))) { @@ -120,19 +120,19 @@ static SmallVector reifyOrComputeDynamicSizes(OpBuilder &b, Value value) { - auto tensorType = value.getType().cast(); + auto tensorType = llvm::cast(value.getType()); if (tensorType.hasStaticShape()) return {}; // Try to reify dynamic sizes. ReifiedRankedShapedTypeDims reifiedShape; - if (value.isa() && + if (llvm::isa(value) && succeeded(reifyResultShapes(b, value.getDefiningOp(), reifiedShape))) { SmallVector dynSizes; for (int64_t i = 0; i < tensorType.getRank(); ++i) { if (tensorType.isDynamicDim(i)) dynSizes.push_back( - reifiedShape[value.cast().getResultNumber()][i] + reifiedShape[llvm::cast(value).getResultNumber()][i] .get()); } return dynSizes; @@ -153,12 +153,12 @@ Value value, Attribute memorySpace = {}) { OpBuilder::InsertionGuard g(rewriter); - auto tensorType = value.getType().cast(); + auto tensorType = llvm::cast(value.getType()); // Create buffer allocation. - auto memrefType = bufferization::getMemRefTypeWithStaticIdentityLayout( - tensorType, memorySpace) - .cast(); + auto memrefType = llvm::cast( + bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType, + memorySpace)); SmallVector dynamicSizes = reifyOrComputeDynamicSizes(rewriter, value); Value alloc = rewriter.create(loc, memrefType, dynamicSizes); @@ -206,7 +206,7 @@ RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp) { Location loc = fromElementsOp.getLoc(); RankedTensorType tensorType = - fromElementsOp.getType().cast(); + llvm::cast(fromElementsOp.getType()); auto shape = tensorType.getShape(); // Create tensor.empty. @@ -247,7 +247,8 @@ return failure(); Location loc = generateOp.getLoc(); - RankedTensorType tensorType = generateOp.getType().cast(); + RankedTensorType tensorType = + llvm::cast(generateOp.getType()); // Create tensor.empty. auto emptyOp = @@ -339,7 +340,7 @@ llvm::map_range(value.getUses(), [](OpOperand &use) { return &use; })); OpBuilder::InsertionGuard g(rewriter); - if (auto bbArg = value.dyn_cast()) { + if (auto bbArg = llvm::dyn_cast(value)) { rewriter.setInsertionPointToStart(bbArg.getOwner()); } else { rewriter.setInsertionPointAfter(value.getDefiningOp()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp @@ -639,8 +639,8 @@ // dynamic case we always need a new destination. auto loc = genericOp.getLoc(); Value unPackDest = producerUnPackOp.getDest(); - auto genericOutType = - genericOp.getDpsInitOperand(0)->get().getType().cast(); + auto genericOutType = llvm::cast( + genericOp.getDpsInitOperand(0)->get().getType()); if (producerUnPackOp.getDestType() != genericOutType || !genericOutType.hasStaticShape()) { unPackDest = tensor::UnPackOp::createDestinationTensor( diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp @@ -132,12 +132,12 @@ static Value getZero(OpBuilder &b, Location loc, Type elementType) { assert(elementType.isIntOrIndexOrFloat() && "expected scalar type while computing zero value"); - if (elementType.isa()) + if (llvm::isa(elementType)) return b.create(loc, 0, elementType); if (elementType.isIndex()) return b.create(loc, 0); // Assume float. - auto floatType = elementType.cast(); + auto floatType = llvm::cast(elementType); return b.create( loc, APFloat::getZero(floatType.getFloatSemantics()), floatType); } @@ -179,7 +179,8 @@ if (resultNumber) { newInitValues.push_back( genericOp.getDpsInitOperand(*resultNumber)->get()); - OpResult result = genericOp.getResult(*resultNumber).cast(); + OpResult result = + llvm::cast(genericOp.getResult(*resultNumber)); newResultTypes.push_back(result.getType()); peeledGenericOpIndexingMaps.push_back( genericOp.getIndexingMapMatchingResult(result)); @@ -231,7 +232,8 @@ })); for (auto resultNum : llvm::seq(origNumResults, peeledGenericOpNumResults)) { - OpResult result = peeledGenericOp.getResult(resultNum).cast(); + OpResult result = + llvm::cast(peeledGenericOp.getResult(resultNum)); indexingMaps.push_back( peeledGenericOp.getIndexingMapMatchingResult(result)); } @@ -348,7 +350,7 @@ /// the peeled operation. SmallVector replacements; for (const auto &yieldValue : llvm::enumerate(yieldOp->getOperands())) { - OpResult opr = yieldValue.value().dyn_cast(); + OpResult opr = llvm::dyn_cast(yieldValue.value()); if (!opr || opr.getOwner() != peeledScalarOperation) replacements.push_back(residualGenericOp.getResult(yieldValue.index())); else diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -32,7 +32,7 @@ ValueRange inputs, Location loc) { assert(inputs.size() == 1); auto inputType = inputs[0].getType(); - if (inputType.isa()) + if (llvm::isa(inputType)) return nullptr; // A detensored value is converted back by creating a new tensor from its @@ -320,9 +320,9 @@ // * Add the argument to blockArgsToDetensor. // * Walk the use-def chain backwards to add each predecessor's // terminator-operands corresponding to currentItem to workList. - if (currentItem.dyn_cast()) { + if (llvm::dyn_cast(currentItem)) { BlockArgument currentItemBlockArgument = - currentItem.cast(); + llvm::cast(currentItem); Block *ownerBlock = currentItemBlockArgument.getOwner(); // Function arguments are not detensored/converted. diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -308,7 +308,8 @@ for (OpOperand *op : candidates) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointAfterValue(op->get()); - auto elemType = op->get().getType().cast().getElementType(); + auto elemType = + llvm::cast(op->get().getType()).getElementType(); auto empty = rewriter.create( loc, tensor::createDimValues(rewriter, loc, op->get()), elemType); @@ -387,7 +388,7 @@ // Early return for memrefs with affine maps to represent that we will always // leave them unchanged. Type actualType = opOperand->get().getType(); - if (auto memref = actualType.dyn_cast()) { + if (auto memref = llvm::dyn_cast(actualType)) { if (!memref.getLayout().isIdentity()) return std::nullopt; } @@ -437,7 +438,7 @@ ArrayRef reassociation, Location loc, PatternRewriter &rewriter) const { // There are no results for memref outputs. - auto origResultType = origOutput.getType().cast(); + auto origResultType = llvm::cast(origOutput.getType()); if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) { unsigned rank = origResultType.getRank(); SmallVector offsets(rank, rewriter.getIndexAttr(0)); @@ -459,7 +460,7 @@ Value collapseValue(Value operand, ArrayRef targetShape, ArrayRef reassociation, Location loc, PatternRewriter &rewriter) const { - if (auto memrefType = operand.getType().dyn_cast()) { + if (auto memrefType = llvm::dyn_cast(operand.getType())) { if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) { FailureOr rankReducingExtract = memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand, @@ -478,7 +479,7 @@ return rewriter.create(loc, targetType, operand, reassociation); } - if (auto tensorType = operand.getType().dyn_cast()) { + if (auto tensorType = llvm::dyn_cast(operand.getType())) { if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) { FailureOr rankReducingExtract = tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand, @@ -502,7 +503,7 @@ PatternRewriter &rewriter) const override { // Skip the pattern if the op has any tensor with special encoding. if (llvm::any_of(genericOp->getOperandTypes(), [](Type type) { - auto tensorType = type.dyn_cast(); + auto tensorType = llvm::dyn_cast(type); return tensorType && tensorType.getEncoding() != nullptr; })) return failure(); @@ -607,11 +608,10 @@ if (!reassociation || reassociation->size() == static_cast(resultType.getRank())) return failure(); - auto rankReducedType = + auto rankReducedType = llvm::cast( tensor::ExtractSliceOp::inferCanonicalRankReducedResultType( reassociation->size(), sliceOp.getSourceType(), offsets, sizes, - strides) - .cast(); + strides)); Location loc = sliceOp.getLoc(); Value newSlice = rewriter.create( diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -87,7 +87,7 @@ // type. Producer must have full tensor semantics to avoid potential // aliasing between producer and consumer memrefs. if (!producer.hasTensorSemantics() || - !fusedOperand->get().getType().isa()) + !llvm::isa(fusedOperand->get().getType())) return false; // Verify that @@ -232,14 +232,14 @@ // forward the yield operand. auto producerYieldOp = cast(producerBlock.getTerminator()); unsigned producerResultNumber = - fusedOperand->get().cast().getResultNumber(); + llvm::cast(fusedOperand->get()).getResultNumber(); Value replacement = mapper.lookupOrDefault(producerYieldOp.getOperand(producerResultNumber)); // Sanity checks, if replacement is not already in the mapper then it must be // produced outside. if (replacement == producerYieldOp.getOperand(producerResultNumber)) { - if (auto bb = replacement.dyn_cast()) + if (auto bb = llvm::dyn_cast(replacement)) assert(bb.getOwner() != &producerBlock && "yielded block argument must have been mapped"); else @@ -278,7 +278,7 @@ OpOperand *fusedOperand) { assert(areElementwiseOpsFusable(fusedOperand) && "expected elementwise operation pre-conditions to pass"); - auto producerResult = fusedOperand->get().cast(); + auto producerResult = llvm::cast(fusedOperand->get()); auto producer = cast(producerResult.getOwner()); auto consumer = cast(fusedOperand->getOwner()); // TODO: allow fusing the producer of an output operand. @@ -357,7 +357,7 @@ fusedOutputOperands.push_back(opOperand->get()); fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand)); Type resultType = opOperand->get().getType(); - if (!resultType.isa()) + if (!llvm::isa(resultType)) fusedResultTypes.push_back(resultType); } @@ -512,7 +512,7 @@ return genericOp.hasTensorSemantics() && llvm::all_of(genericOp.getIndexingMaps().getValue(), [](Attribute attr) { - return attr.cast() + return llvm::cast(attr) .getValue() .isProjectedPermutation(); }) && @@ -776,7 +776,7 @@ continue; } if (auto opOperandType = - opOperand->get().getType().dyn_cast()) { + llvm::dyn_cast(opOperand->get().getType())) { AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand); RankedTensorType expandedOperandType = getExpandedType(opOperandType, indexingMap, expansionInfo); @@ -805,7 +805,8 @@ SmallVector outputs; for (OpOperand *opOperand : genericOp.getDpsInitOperands()) { AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand); - auto opOperandType = opOperand->get().getType().cast(); + auto opOperandType = + llvm::cast(opOperand->get().getType()); RankedTensorType expandedOutputType = getExpandedType(opOperandType, indexingMap, expansionInfo); if (expandedOutputType != opOperand->get().getType()) { @@ -921,7 +922,7 @@ LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp, PatternRewriter &rewriter) const override { // Fold only if all constraints of fusing with reshape by expansion are met. - auto producerResult = reshapeOp.getSrc().dyn_cast(); + auto producerResult = llvm::dyn_cast(reshapeOp.getSrc()); if (!producerResult) { return rewriter.notifyMatchFailure(reshapeOp, "source not produced by an operation"); @@ -959,8 +960,9 @@ // same type as the returns of the original generic op, the consumer reshape // op can be replaced by the source of the collapse_shape op that defines // the replacement. - Value reshapeReplacement = (*replacementValues) - [reshapeOp.getSrc().cast().getResultNumber()]; + Value reshapeReplacement = + (*replacementValues)[llvm::cast(reshapeOp.getSrc()) + .getResultNumber()]; if (auto collapseOp = reshapeReplacement.getDefiningOp()) { reshapeReplacement = collapseOp.getSrc(); @@ -1447,7 +1449,7 @@ .createLoopRanges(rewriter, genericOp.getLoc()); auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) { if (auto attr = ofr.dyn_cast()) - return attr.cast().getInt() == value; + return llvm::cast(attr).getInt() == value; llvm::APInt actual; return matchPattern(ofr.get(), m_ConstantInt(&actual)) && actual.getSExtValue() == value; @@ -1521,8 +1523,9 @@ Value collapsedOpResult = collapsedGenericOp->getResult(originalResult.index()); auto originalResultType = - originalResult.value().getType().cast(); - auto collapsedOpResultType = collapsedOpResult.getType().cast(); + llvm::cast(originalResult.value().getType()); + auto collapsedOpResultType = + llvm::cast(collapsedOpResult.getType()); if (collapsedOpResultType.getRank() != originalResultType.getRank()) { AffineMap indexingMap = genericOp.getIndexingMapMatchingResult(originalResult.value()); @@ -1671,7 +1674,7 @@ return false; }; - auto resultValue = opOperand->get().dyn_cast(); + auto resultValue = llvm::dyn_cast(opOperand->get()); if (!def || !resultValue || !isScalarOrSplatConstantOp(def)) continue; @@ -1756,7 +1759,8 @@ for (OpOperand *opOperand : op.getDpsInitOperands()) { if (!op.payloadUsesValueFromOperand(opOperand)) { Value operandVal = opOperand->get(); - auto operandType = operandVal.getType().dyn_cast(); + auto operandType = + llvm::dyn_cast(operandVal.getType()); if (!operandType) continue; @@ -1809,8 +1813,8 @@ continue; fillFound = true; Value fillVal = fillOp.value(); - auto resultType = - fillOp.result().getType().cast().getElementType(); + auto resultType = llvm::cast(fillOp.result().getType()) + .getElementType(); Value convertedVal = convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType, /*isUnsignedCast =*/false); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp @@ -27,8 +27,9 @@ // TODO: The conversion pattern can be made to work for `any_of` here, but // it's more complex as it requires tracking which operands are scalars. - return llvm::all_of(op->getOperandTypes(), - [](Type type) { return type.isa(); }); + return llvm::all_of(op->getOperandTypes(), [](Type type) { + return llvm::isa(type); + }); } /// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over @@ -67,7 +68,7 @@ // Extract static / dynamic shape mix from the first operand. Value firstOperand = operands.front(); - auto rankedTensorType = t.cast(); + auto rankedTensorType = llvm::cast(t); auto staticShape = llvm::to_vector<4>(rankedTensorType.getShape()); auto dynamicShape = linalg::createDynamicDimensions(b, loc, firstOperand); @@ -87,7 +88,8 @@ return rewriter.notifyMatchFailure( op, "requires elementwise op on ranked tensors"); - auto rank = op->getResult(0).getType().cast().getRank(); + auto rank = + llvm::cast(op->getResult(0).getType()).getRank(); SmallVector indexingMaps( op->getNumResults() + op->getNumOperands(), rewriter.getMultiDimIdentityMap(rank)); @@ -104,7 +106,7 @@ [&](OpBuilder &builder, Location loc, ValueRange regionArgs) { auto resultTypes = llvm::to_vector<6>( llvm::map_range(op->getResultTypes(), [](Type type) { - return type.cast().getElementType(); + return llvm::cast(type).getElementType(); })); auto *scalarOp = builder.create(loc, op->getName().getIdentifier(), diff --git a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp @@ -89,7 +89,7 @@ Location loc = genericOp.getLoc(); SmallVector newResultTypes; for (Value v : newOutputOperands) - if (v.getType().isa()) + if (llvm::isa(v.getType())) newResultTypes.push_back(v.getType()); auto newOp = rewriter.create( loc, newResultTypes, newInputOperands, newOutputOperands, diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp @@ -86,12 +86,12 @@ // result of the generic op. The low pad values are the offsets, the size of // the source is the size of the slice. // TODO: This insert/extract could be potentially made a utility method. - unsigned resultNumber = source.cast().getResultNumber(); + unsigned resultNumber = llvm::cast(source).getResultNumber(); SmallVector offsets = padOp.getMixedLowPad(); SmallVector sizes; sizes.reserve(offsets.size()); for (const auto &shape : llvm::enumerate( - source.getType().cast().getShape())) { + llvm::cast(source.getType()).getShape())) { if (ShapedType::isDynamic(shape.value())) { sizes.push_back( rewriter.create(loc, source, shape.index()) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -151,7 +151,8 @@ SmallVector resultTypes; resultTypes.reserve(producer->getNumResults()); for (OpOperand *operand : producer.getDpsInitOperands()) { - auto tensorType = operand->get().getType().dyn_cast(); + auto tensorType = + llvm::dyn_cast(operand->get().getType()); if (!tensorType) continue; unsigned rank = tensorType.getRank(); @@ -210,20 +211,20 @@ // dependence tracking since the dependence tracking is similar to what is done // w.r.t to buffers. static void getProducerOfTensor(Value tensor, OpResult &opResult) { - if (!tensor.getType().isa()) + if (!llvm::isa(tensor.getType())) return; while (true) { LLVM_DEBUG(llvm::dbgs() << "\ngetProducerOfTensor: " << tensor); if (auto linalgOp = tensor.getDefiningOp()) { - opResult = tensor.cast(); + opResult = llvm::cast(tensor); return; } if (auto sliceOp = tensor.getDefiningOp()) { tensor = sliceOp.getSource(); continue; } - if (auto blockArg = tensor.dyn_cast()) { + if (auto blockArg = llvm::dyn_cast(tensor)) { if (auto forOp = blockArg.getDefiningOp()) { tensor = *(forOp.getIterOperands().begin() + blockArg.getArgNumber()); continue; diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -227,7 +227,7 @@ return {}; bbArgs.push_back(bbArg); OpOperand *iterArg = &tileLoop.getOpOperandForRegionIterArg(bbArg); - bbArg = iterArg->get().dyn_cast(); + bbArg = llvm::dyn_cast(iterArg->get()); } // Reverse the block arguments to order them from outer to inner. @@ -358,13 +358,13 @@ // Check if the producer is a LinalgOp possibly passed by iteration argument. OpOperand *iterArg = nullptr; - auto producerResult = sliceOp.getSource().dyn_cast(); - if (auto bbArg = sliceOp.getSource().dyn_cast()) { + auto producerResult = llvm::dyn_cast(sliceOp.getSource()); + if (auto bbArg = llvm::dyn_cast(sliceOp.getSource())) { iterArg = getTiedIterArg(bbArg); // Check the iteration argument may be used to pass in the producer output. if (!iterArg || hasOtherUses(bbArg, sliceOp)) return failure(); - producerResult = iterArg->get().dyn_cast(); + producerResult = llvm::dyn_cast(iterArg->get()); } if (!producerResult || !isa(producerResult.getOwner())) return failure(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp @@ -549,7 +549,7 @@ int paddedRank = paddedTensorType.getRank(); // Step 0. Populate bvm with opToHoist.getSource if relevant. - BlockArgument bbArg = opToHoist.getSource().dyn_cast(); + BlockArgument bbArg = llvm::dyn_cast(opToHoist.getSource()); while (bbArg) { auto forOp = dyn_cast(bbArg.getOwner()->getParentOp()); if (!forOp) @@ -558,7 +558,7 @@ break; OpOperand &operand = forOp.getOpOperandForRegionIterArg(bbArg); bvm.map(bbArg, operand.get()); - bbArg = operand.get().dyn_cast(); + bbArg = llvm::dyn_cast(operand.get()); } // Step 1. iteratively clone loops and push `hoistedPackedTensor`. @@ -754,7 +754,8 @@ break; LLVM_DEBUG(DBGS() << "--step dest op: " << destOp << "\n"); source = - destOp.getDpsInitOperand(source.cast().getResultNumber()) + destOp + .getDpsInitOperand(llvm::cast(source).getResultNumber()) ->get(); } LLVM_DEBUG(DBGS() << "--final source: " << source << "\n"); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -86,7 +86,7 @@ [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); }); func.walk([&](vector::TransferReadOp transferRead) { - if (!transferRead.getShapedType().isa()) + if (!llvm::isa(transferRead.getShapedType())) return WalkResult::advance(); LLVM_DEBUG(DBGS() << "Candidate for hoisting: " diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -162,7 +162,7 @@ SmallVector, 8> indexing; SmallVector outputBuffers; for (OpOperand *outputOperand : linalgOp.getDpsInitOperands()) { - if (!outputOperand->get().getType().isa()) + if (!llvm::isa(outputOperand->get().getType())) continue; indexing.push_back(makeCanonicalAffineApplies( b, loc, linalgOp.getMatchingIndexingMap(outputOperand), @@ -242,7 +242,7 @@ return failure(); // The induction variable is a block argument of the entry block of the // loop operation. - BlockArgument ivVal = iv.dyn_cast(); + BlockArgument ivVal = llvm::dyn_cast(iv); if (!ivVal) return failure(); loopSet.insert(ivVal.getOwner()->getParentOp()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp @@ -44,9 +44,9 @@ auto result = operation->getResult(0); - auto kernelTy = kernel.getType().dyn_cast(); - auto initTy = init.getType().dyn_cast(); - auto resultTy = result.getType().template dyn_cast(); + auto kernelTy = llvm::dyn_cast(kernel.getType()); + auto initTy = llvm::dyn_cast(init.getType()); + auto resultTy = llvm::dyn_cast(result.getType()); if (!kernelTy || !initTy || !resultTy) return failure(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -292,9 +292,9 @@ }) .Case([&](ComplexType t) { Value tmp; - if (auto et = t.getElementType().dyn_cast()) + if (auto et = llvm::dyn_cast(t.getElementType())) tmp = b.create(FloatAttr::get(et, 0.0)); - else if (auto et = t.getElementType().cast()) + else if (auto et = llvm::cast(t.getElementType())) tmp = b.create(IntegerAttr::get(et, 0)); return b.create(t, tmp, tmp); }) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp @@ -93,7 +93,7 @@ {iterationSpace[dimension].offset, iterationSpace[dimension].size, minSplitPoint}); if (auto attr = remainingSize.dyn_cast()) { - if (attr.cast().getValue().isZero()) + if (llvm::cast(attr).getValue().isZero()) return {op, TilingInterface()}; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp @@ -112,8 +112,8 @@ continue; } Type newType = RankedTensorType::get( - newShape, - operand->get().getType().cast().getElementType()); + newShape, llvm::cast(operand->get().getType()) + .getElementType()); Value newInput = b.create( loc, newType, operand->get(), reassociation); newInputs.push_back(newInput); @@ -309,7 +309,7 @@ fillOps.reserve(op.getNumDpsInits()); for (auto it : llvm::zip(op.getDpsInitOperands(), neutralElements)) { Value rankedTensor = std::get<0>(it)->get(); - auto t = rankedTensor.getType().cast(); + auto t = llvm::cast(rankedTensor.getType()); RankedTensorType newT = RankedTensorType::Builder(t).insertDim( reductionDimSize / splitFactor, insertSplitDimension); SmallVector dims = @@ -383,7 +383,8 @@ combinerOps)) { Value reindexedOutput = std::get<0>(it); Value originalOutput = std::get<1>(it)->get(); - auto originalOutputType = originalOutput.getType().cast(); + auto originalOutputType = + llvm::cast(originalOutput.getType()); Operation *combinerOp = std::get<2>(it); AffineMap map = b.getMultiDimIdentityMap(originalOutputType.getRank() + 1); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -49,7 +49,7 @@ if (!v) return false; if (auto attr = v.dyn_cast()) { - IntegerAttr intAttr = attr.dyn_cast(); + IntegerAttr intAttr = llvm::dyn_cast(attr); return intAttr && intAttr.getValue().isZero(); } if (auto cst = v.get().getDefiningOp()) @@ -105,7 +105,7 @@ static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b, OpFoldResult value) { if (auto attr = value.dyn_cast()) { - assert(attr.cast().getValue().isStrictlyPositive() && + assert(llvm::cast(attr).getValue().isStrictlyPositive() && "expected strictly positive tile size and divisor"); return; } @@ -587,8 +587,8 @@ SmallVector loops; loops.reserve(ivs.size()); for (auto iv : ivs) { - if (iv.isa()) { - loops.push_back(iv.cast().getOwner()->getParentOp()); + if (llvm::isa(iv)) { + loops.push_back(llvm::cast(iv).getOwner()->getParentOp()); assert(loops.back() && "no owner found for induction variable!"); } else { // TODO: Instead of doing this, try to recover the ops used instead of the @@ -712,7 +712,7 @@ outOffsets[reductionDim] = forallOp.getInductionVars().front(); // TODO: use SubsetExtractOpInterface once it is available. tiledDpsInitOperands.push_back(b.create( - loc, initOperand->get().getType().cast(), + loc, llvm::cast(initOperand->get().getType()), destBbArgs[destNum], outOffsets, sizes, strides)); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -366,7 +366,7 @@ // Then create a new reduction that only reduce the newly added dimension // from the previous op. int64_t intermRank = - partialReduce[0].getType().cast().getRank(); + llvm::cast(partialReduce[0].getType()).getRank(); AffineMap inputMap = b.getMultiDimIdentityMap(intermRank); SmallVector reductionIteratorTypes; SmallVector exprs; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -89,7 +89,7 @@ // Follow the use-def chain if `currOpOperand` is defined by a LinalgOp. OpOperand *currOpOperand = opOperand; while (auto linalgOp = currOpOperand->get().getDefiningOp()) { - OpResult result = currOpOperand->get().cast(); + OpResult result = llvm::cast(currOpOperand->get()); currOpOperand = linalgOp.getDpsInitOperand(result.getResultNumber()); } @@ -133,7 +133,7 @@ // If the size is an attribute add it directly to `paddedShape`. if (en.value().is()) { paddedShape[shapeIdx++] = - en.value().get().dyn_cast().getInt(); + llvm::dyn_cast(en.value().get()).getInt(); LLVM_DEBUG( DBGS() << "------dim is an attr, add it to padded shape, SKIP\n"); continue; @@ -232,7 +232,8 @@ for (const auto &en : llvm::enumerate(paddedOp->getResults())) { Value paddedResult = en.value(); int64_t resultNumber = en.index(); - int64_t rank = paddedResult.getType().cast().getRank(); + int64_t rank = + llvm::cast(paddedResult.getType()).getRank(); SmallVector offsets(rank, rewriter.getIndexAttr(0)); SmallVector sizes; SmallVector strides(rank, rewriter.getIndexAttr(1)); @@ -476,7 +477,7 @@ tensor::PackOp packOp) { // 1. Filter out NYI cases. auto packedTensorType = - packOp->getResultTypes().front().cast(); + llvm::cast(packOp->getResultTypes().front()); if (!packedTensorType.hasStaticShape()) { return rewriter.notifyMatchFailure( packOp, @@ -622,7 +623,8 @@ int64_t packedRank = packedTensorType.getRank(); OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1); - auto destTensorType = unPackOp.getDest().getType().cast(); + auto destTensorType = + llvm::cast(unPackOp.getDest().getType()); if (unPackOp.isLikeUnPad()) { // This unpack is just a plain unpad. // Just extract the slice from the higher ranked tensor. @@ -872,7 +874,7 @@ // Sanity check of the expected transposed tensor type. auto tensorType = permuteShape( - opOperand.get().getType().cast(), permutation); + llvm::cast(opOperand.get().getType()), permutation); (void)tensorType; assert(tensorType == transposedValue.getType() && "expected tensor type mismatch"); @@ -1033,8 +1035,8 @@ PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const { - auto inputShapedType = padOp.getSource().getType().cast(); - auto resultShapedType = padOp.getResult().getType().cast(); + auto inputShapedType = llvm::cast(padOp.getSource().getType()); + auto resultShapedType = llvm::cast(padOp.getResult().getType()); // Bail on non-static shapes. if (!inputShapedType.hasStaticShape()) @@ -1051,7 +1053,7 @@ Operation *definingOp = padValue.getDefiningOp(); if (definingOp && definingOp->getBlock() == &block) return failure(); - if (!definingOp && padValue.cast().getOwner() == &block) + if (!definingOp && llvm::cast(padValue).getOwner() == &block) return failure(); // Create tensor with the padded shape @@ -1117,7 +1119,8 @@ return val; return rewriter .create( - padOp.getLoc(), ofr.get().cast().getInt()) + padOp.getLoc(), + llvm::cast(ofr.get()).getInt()) .getResult(); }; @@ -1497,9 +1500,9 @@ Value kernel = convOp.getInputs().back(); Value output = convOp.getOutputs().front(); - auto inputType = input.getType().dyn_cast(); - auto kernelType = kernel.getType().dyn_cast(); - auto outputType = output.getType().dyn_cast(); + auto inputType = llvm::dyn_cast(input.getType()); + auto kernelType = llvm::dyn_cast(kernel.getType()); + auto outputType = llvm::dyn_cast(output.getType()); auto kernelShape = kernelType.getShape(); auto outputShape = outputType.getShape(); @@ -1621,9 +1624,9 @@ Value kernel = convOp.getInputs().back(); Value output = convOp.getOutputs().front(); - auto inputType = input.getType().dyn_cast(); - auto kernelType = kernel.getType().dyn_cast(); - auto outputType = output.getType().dyn_cast(); + auto inputType = llvm::dyn_cast(input.getType()); + auto kernelType = llvm::dyn_cast(kernel.getType()); + auto outputType = llvm::dyn_cast(output.getType()); auto kernelShape = kernelType.getShape(); auto outputShape = outputType.getShape(); @@ -1689,9 +1692,9 @@ Value kernel = convOp.getInputs().back(); Value output = convOp.getOutputs().front(); - auto inputType = input.getType().dyn_cast(); - auto kernelType = kernel.getType().dyn_cast(); - auto outputType = output.getType().dyn_cast(); + auto inputType = llvm::dyn_cast(input.getType()); + auto kernelType = llvm::dyn_cast(kernel.getType()); + auto outputType = llvm::dyn_cast(output.getType()); auto kernelShape = kernelType.getShape(); auto outputShape = outputType.getShape(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -563,7 +563,7 @@ loc, value, outputOperand->get(), indices, writeMap); } else { // 0-d case is still special: do not invert the reindexing writeMap. - if (!value.getType().isa()) + if (!llvm::isa(value.getType())) value = rewriter.create(loc, vectorType, value); assert(value.getType() == vectorType && "incorrect type"); write = rewriter.create( @@ -864,7 +864,7 @@ targetShape.back() == 1) return VectorMemoryAccessKind::Gather; - auto inputShape = extractOp.getTensor().getType().cast(); + auto inputShape = llvm::cast(extractOp.getTensor().getType()); // 2. Assume that it's a gather load when reading _from_ a tensor for which // the trailing dimension is 1, e.g. `tensor<1x4x1xi32>`. @@ -1024,8 +1024,8 @@ const IRMapping &bvm) { Value reduceVec = bvm.lookup(reduceValue); Value outputVec = bvm.lookup(initialValue); - auto reduceType = reduceVec.getType().dyn_cast(); - auto outputType = outputVec.getType().dyn_cast(); + auto reduceType = llvm::dyn_cast(reduceVec.getType()); + auto outputType = llvm::dyn_cast(outputVec.getType()); // Reduce only if needed as the value may already have been reduce for // contraction vectorization. if (!reduceType || @@ -1082,7 +1082,7 @@ // 4 . Check if the operation is a reduction. SmallVector> reductionOperands; for (Value operand : op->getOperands()) { - auto blockArg = operand.dyn_cast(); + auto blockArg = llvm::dyn_cast(operand); if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() || blockArg.getArgNumber() < linalgOp.getNumDpsInputs()) continue; @@ -1107,7 +1107,7 @@ // a. first get the first max ranked shape. SmallVector firstMaxRankedShape; for (Value operand : op->getOperands()) { - auto vt = bvm.lookup(operand).getType().dyn_cast(); + auto vt = llvm::dyn_cast(bvm.lookup(operand).getType()); if (vt && firstMaxRankedShape.size() < vt.getShape().size()) firstMaxRankedShape.assign(vt.getShape().begin(), vt.getShape().end()); } @@ -1230,7 +1230,7 @@ // 3.c. Not all ops support 0-d vectors, extract the scalar for now. // TODO: remove this. - if (readValue.getType().cast().getRank() == 0) + if (llvm::cast(readValue.getType()).getRank() == 0) readValue = rewriter.create(loc, readValue); LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue @@ -1528,8 +1528,8 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter, memref::CopyOp copyOp) { - auto srcType = copyOp.getSource().getType().cast(); - auto dstType = copyOp.getTarget().getType().cast(); + auto srcType = llvm::cast(copyOp.getSource().getType()); + auto dstType = llvm::cast(copyOp.getTarget().getType()); if (!srcType.hasStaticShape() || !dstType.hasStaticShape()) return failure(); @@ -1549,7 +1549,7 @@ Value readValue = rewriter.create( loc, readType, copyOp.getSource(), indices, rewriter.getMultiDimIdentityMap(srcType.getRank())); - if (readValue.getType().cast().getRank() == 0) { + if (llvm::cast(readValue.getType()).getRank() == 0) { readValue = rewriter.create(loc, readValue); readValue = rewriter.create(loc, writeType, readValue); } @@ -1566,7 +1566,7 @@ /// Helper function that retrieves the value of an IntegerAttr. static int64_t getIntFromAttr(Attribute attr) { - return attr.cast().getInt(); + return llvm::cast(attr).getInt(); } /// Given an ArrayRef of OpFoldResults, return a vector of Values. @@ -1836,8 +1836,8 @@ if (hasSameTensorSize(castOp.getSource(), afterTrimming)) return true; - auto t1 = beforePadding.getType().dyn_cast(); - auto t2 = afterTrimming.getType().dyn_cast(); + auto t1 = llvm::dyn_cast(beforePadding.getType()); + auto t2 = llvm::dyn_cast(afterTrimming.getType()); // Only RankedTensorType supported. if (!t1 || !t2) return false; @@ -1946,7 +1946,7 @@ if (!padValue) return failure(); // Dynamic shapes not supported. - if (!padOp.getResult().getType().cast().hasStaticShape()) + if (!llvm::cast(padOp.getResult().getType()).hasStaticShape()) return failure(); // Pad result not used as destination. if (insertOp.getDest() == padOp.getResult()) @@ -2074,7 +2074,7 @@ memref::CopyOp copyOp; for (auto &u : subView.getUses()) { if (auto newCopyOp = dyn_cast(u.getOwner())) { - assert(newCopyOp.getTarget().getType().isa()); + assert(llvm::isa(newCopyOp.getTarget().getType())); if (newCopyOp.getTarget() != subView) continue; if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView})) @@ -2091,7 +2091,7 @@ FillOp maybeFillOp; for (auto &u : viewOrAlloc.getUses()) { if (auto newFillOp = dyn_cast(u.getOwner())) { - assert(newFillOp.output().getType().isa()); + assert(llvm::isa(newFillOp.output().getType())); if (newFillOp.output() != viewOrAlloc) continue; if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView})) @@ -2162,7 +2162,7 @@ return rewriter.notifyMatchFailure(xferOp, "no copy found"); // `out` is the subview copied into that we replace. - assert(copyOp.getTarget().getType().isa()); + assert(llvm::isa(copyOp.getTarget().getType())); Value out = copyOp.getTarget(); // Forward vector.transfer into copy. @@ -2204,7 +2204,7 @@ namespace { bool isCastOfBlockArgument(Operation *op) { return isa(op) && op->getNumOperands() == 1 && - op->getOperand(0).isa(); + llvm::isa(op->getOperand(0)); } bool isSupportedPoolKind(vector::CombiningKind kind) { @@ -2268,9 +2268,9 @@ lhsShaped = linalgOp.getDpsInputOperand(0)->get(); rhsShaped = linalgOp.getDpsInputOperand(1)->get(); resShaped = linalgOp.getDpsInitOperand(0)->get(); - lhsShapedType = lhsShaped.getType().dyn_cast(); - rhsShapedType = rhsShaped.getType().dyn_cast(); - resShapedType = resShaped.getType().dyn_cast(); + lhsShapedType = llvm::dyn_cast(lhsShaped.getType()); + rhsShapedType = llvm::dyn_cast(rhsShaped.getType()); + resShapedType = llvm::dyn_cast(resShaped.getType()); if (!lhsShapedType || !rhsShapedType || !resShapedType) return; // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR @@ -2717,8 +2717,8 @@ /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to MulAcc Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc, Value lhs, Value rhs, Value res) { - auto rhsTy = rhs.getType().cast(); - auto resTy = res.getType().cast(); + auto rhsTy = llvm::cast(rhs.getType()); + auto resTy = llvm::cast(res.getType()); // TODO(suderman): Change this to use a vector.ima intrinsic. lhs = promote(rewriter, loc, lhs, resTy); @@ -2730,7 +2730,7 @@ if (!lhs || !rhs) return nullptr; - if (resTy.getElementType().isa()) + if (llvm::isa(resTy.getElementType())) return rewriter.create(loc, lhs, rhs, res); auto mul = rewriter.create(loc, lhs, rhs); @@ -2865,13 +2865,13 @@ bool setOperKind(Operation *reduceOp) { int numBlockArguments = llvm::count_if(reduceOp->getOperands(), - [](Value v) { return v.isa(); }); + [](Value v) { return llvm::isa(v); }); switch (numBlockArguments) { case 1: { // Will be convolution if feeder is a MulOp. // Otherwise, if it can be pooling. auto feedValIt = llvm::find_if(reduceOp->getOperands(), [](Value v) { - return !v.isa(); + return !llvm::isa(v); }); Operation *feedOp = (*feedValIt).getDefiningOp(); if (isCastOfBlockArgument(feedOp)) { @@ -2880,7 +2880,7 @@ poolExtOp = feedOp->getName().getIdentifier(); } else if (!(isa(feedOp) && llvm::all_of(feedOp->getOperands(), [](Value v) { - if (v.isa()) + if (llvm::isa(v)) return true; if (Operation *op = v.getDefiningOp()) return isCastOfBlockArgument(op); diff --git a/mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp @@ -43,16 +43,16 @@ namespace mlir { namespace linalg { Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim) { - if (val.getType().isa()) + if (llvm::isa(val.getType())) return b.createOrFold(loc, val, dim); - if (val.getType().isa()) + if (llvm::isa(val.getType())) return b.createOrFold(loc, val, dim); llvm_unreachable("Expected MemRefType or TensorType"); } OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim) { - auto shapedType = val.getType().cast(); + auto shapedType = llvm::cast(val.getType()); if (!shapedType.hasRank() || shapedType.isDynamicDim(dim)) return createOrFoldDimOp(b, loc, val, dim); return b.getIndexAttr(shapedType.getDimSize(dim)); @@ -60,7 +60,7 @@ SmallVector createDynamicDimensions(OpBuilder &b, Location loc, Value val) { - auto shapedType = val.getType().cast(); + auto shapedType = llvm::cast(val.getType()); assert(shapedType.hasRank() && "`val` must have a static rank"); SmallVector res; res.reserve(shapedType.getRank()); @@ -73,7 +73,7 @@ SmallVector getMixedDimensions(OpBuilder &b, Location loc, Value val) { - auto shapedType = val.getType().cast(); + auto shapedType = llvm::cast(val.getType()); assert(shapedType.hasRank() && "`val` must have a static rank"); SmallVector dynamicDims = createDynamicDimensions(b, loc, val); return getMixedValues(shapedType.getShape(), dynamicDims, b); diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -281,7 +281,7 @@ auto linalgOp = current.getDefiningOp(); if (!linalgOp) break; - OpResult opResult = current.cast(); + OpResult opResult = llvm::cast(current); current = linalgOp.getDpsInitOperand(opResult.getResultNumber())->get(); } auto padOp = current ? current.getDefiningOp() : nullptr; @@ -331,7 +331,7 @@ GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor, Value outputTensor, ArrayRef transposeVector) { - auto resultTensorType = outputTensor.getType().cast(); + auto resultTensorType = llvm::cast(outputTensor.getType()); Type elementType = resultTensorType.getElementType(); assert(isPermutationVector(transposeVector) && @@ -366,9 +366,9 @@ } GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) { - auto memrefTypeTo = to.getType().cast(); + auto memrefTypeTo = llvm::cast(to.getType()); #ifndef NDEBUG - auto memrefTypeFrom = from.getType().cast(); + auto memrefTypeFrom = llvm::cast(from.getType()); assert(memrefTypeFrom.getRank() == memrefTypeTo.getRank() && "`from` and `to` memref must have the same rank"); #endif // NDEBUG @@ -650,7 +650,7 @@ static Value materializeTiledShape(OpBuilder &builder, Location loc, Value valueToTile, const SliceParameters &sliceParams) { - auto shapedType = valueToTile.getType().dyn_cast(); + auto shapedType = llvm::dyn_cast(valueToTile.getType()); auto *sliceOp = TypeSwitch(shapedType) .Case([&](MemRefType) { return builder.create( @@ -685,7 +685,7 @@ ArrayRef lbs, ArrayRef ubs, ArrayRef subShapeSizes, bool omitPartialTileCheck) { - auto shapedType = valueToTile.getType().dyn_cast(); + auto shapedType = llvm::dyn_cast(valueToTile.getType()); assert(shapedType && "only shaped types can be tiled"); ArrayRef shape = shapedType.getShape(); int64_t rank = shapedType.getRank(); @@ -889,8 +889,9 @@ // subdomains explicit. Type operandType = opOperand.get().getType(); - if (!isTiled(map, tileSizes) && !(operandType.isa() && - linalgOp.isDpsInit(&opOperand))) { + if (!isTiled(map, tileSizes) && + !(llvm::isa(operandType) && + linalgOp.isDpsInit(&opOperand))) { allSliceParams.push_back(std::nullopt); LLVM_DEBUG(llvm::dbgs() << ": not tiled: use shape: " << operandType << "\n"); @@ -971,7 +972,7 @@ auto size = it.value(); curr.push_back(dim); auto attr = size.dyn_cast(); - if (attr && attr.cast().getInt() == 1) + if (attr && llvm::cast(attr).getInt() == 1) continue; reassociation.emplace_back(ReassociationIndices{}); std::swap(reassociation.back(), curr); @@ -989,7 +990,7 @@ // Builder only used as helper for attribute creation. OpBuilder b(op->getContext()); Type resultType = op->getResult(0).getType(); - if (auto floatType = resultType.dyn_cast()) { + if (auto floatType = llvm::dyn_cast(resultType)) { const llvm::fltSemantics &semantic = floatType.getFloatSemantics(); if (isa(op)) return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic)); diff --git a/mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp b/mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp --- a/mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp +++ b/mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp @@ -28,7 +28,7 @@ using OpAsmDialectInterface::OpAsmDialectInterface; AliasResult getAlias(Attribute attr, raw_ostream &os) const override { - if (attr.isa()) { + if (llvm::isa(attr)) { os << "extern"; return AliasResult::OverridableAlias; } diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp @@ -23,7 +23,7 @@ //===----------------------------------------------------------------------===// static bool isSupportedElementType(Type type) { - return type.isa() || + return llvm::isa(type) || OpBuilder(type.getContext()).getZeroAttr(type); } diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -90,7 +90,7 @@ for (OpOperand &operand : op->getOpOperands()) { auto cast = operand.get().getDefiningOp(); if (cast && operand.get() != inner && - !cast.getOperand().getType().isa()) { + !llvm::isa(cast.getOperand().getType())) { operand.set(cast.getOperand()); folded = true; } @@ -101,16 +101,16 @@ /// Return an unranked/ranked tensor type for the given unranked/ranked memref /// type. Type mlir::memref::getTensorTypeFromMemRefType(Type type) { - if (auto memref = type.dyn_cast()) + if (auto memref = llvm::dyn_cast(type)) return RankedTensorType::get(memref.getShape(), memref.getElementType()); - if (auto memref = type.dyn_cast()) + if (auto memref = llvm::dyn_cast(type)) return UnrankedTensorType::get(memref.getElementType()); return NoneType::get(type.getContext()); } SmallVector memref::getMixedSizes(OpBuilder &builder, Location loc, Value value) { - auto memrefType = value.getType().cast(); + auto memrefType = llvm::cast(value.getType()); SmallVector result; for (int64_t i = 0; i < memrefType.getRank(); ++i) { if (memrefType.isDynamicDim(i)) { @@ -180,7 +180,7 @@ // values, hence we recreate the attribute even when it is already static // to make sure the type is consistent. ofr = builder.getIndexAttr( - ofr.get().cast().getInt()); + llvm::cast(ofr.get()).getInt()); continue; } std::optional maybeConstant = @@ -241,7 +241,7 @@ static LogicalResult verifyAllocLikeOp(AllocLikeOp op) { static_assert(llvm::is_one_of::value, "applies to only alloc or alloca"); - auto memRefType = op.getResult().getType().template dyn_cast(); + auto memRefType = llvm::dyn_cast(op.getResult().getType()); if (!memRefType) return op.emitOpError("result must be a memref"); @@ -378,7 +378,7 @@ //===----------------------------------------------------------------------===// LogicalResult ReallocOp::verify() { - auto sourceType = getOperand(0).getType().cast(); + auto sourceType = llvm::cast(getOperand(0).getType()); MemRefType resultType = getType(); // The source memref should have identity layout (or none). @@ -691,8 +691,9 @@ /// consumer %0 ... : memref(16 * i + j)>> /// ``` bool CastOp::canFoldIntoConsumerOp(CastOp castOp) { - MemRefType sourceType = castOp.getSource().getType().dyn_cast(); - MemRefType resultType = castOp.getType().dyn_cast(); + MemRefType sourceType = + llvm::dyn_cast(castOp.getSource().getType()); + MemRefType resultType = llvm::dyn_cast(castOp.getType()); // Requires ranked MemRefType. if (!sourceType || !resultType) @@ -743,11 +744,11 @@ if (inputs.size() != 1 || outputs.size() != 1) return false; Type a = inputs.front(), b = outputs.front(); - auto aT = a.dyn_cast(); - auto bT = b.dyn_cast(); + auto aT = llvm::dyn_cast(a); + auto bT = llvm::dyn_cast(b); - auto uaT = a.dyn_cast(); - auto ubT = b.dyn_cast(); + auto uaT = llvm::dyn_cast(a); + auto ubT = llvm::dyn_cast(b); if (aT && bT) { if (aT.getElementType() != bT.getElementType()) @@ -831,8 +832,8 @@ // Check source. if (auto castOp = copyOp.getSource().getDefiningOp()) { - auto fromType = castOp.getSource().getType().dyn_cast(); - auto toType = castOp.getSource().getType().dyn_cast(); + auto fromType = llvm::dyn_cast(castOp.getSource().getType()); + auto toType = llvm::dyn_cast(castOp.getSource().getType()); if (fromType && toType) { if (fromType.getShape() == toType.getShape() && @@ -847,8 +848,8 @@ // Check target. if (auto castOp = copyOp.getTarget().getDefiningOp()) { - auto fromType = castOp.getSource().getType().dyn_cast(); - auto toType = castOp.getSource().getType().dyn_cast(); + auto fromType = llvm::dyn_cast(castOp.getSource().getType()); + auto toType = llvm::dyn_cast(castOp.getSource().getType()); if (fromType && toType) { if (fromType.getShape() == toType.getShape() && @@ -970,7 +971,7 @@ for (const auto &dim : llvm::enumerate(sizes)) if (auto attr = dim.value().dyn_cast()) - if (attr.cast().getInt() == 1) + if (llvm::cast(attr).getInt() == 1) unusedDims.set(dim.index()); // Early exit for the case where the number of unused dims matches the number @@ -1046,7 +1047,7 @@ return {}; // Folding for unranked types (UnrankedMemRefType) is not supported. - auto memrefType = getSource().getType().dyn_cast(); + auto memrefType = llvm::dyn_cast(getSource().getType()); if (!memrefType) return {}; @@ -1256,7 +1257,7 @@ // Check types of operands. The order of these calls is important: the later // calls rely on some type properties to compute the operand position. // 1. Source memref. - if (!getSrcMemRef().getType().isa()) + if (!llvm::isa(getSrcMemRef().getType())) return emitOpError("expected source to be of memref type"); if (numOperands < getSrcMemRefRank() + 4) return emitOpError() << "expected at least " << getSrcMemRefRank() + 4 @@ -1267,7 +1268,7 @@ return emitOpError("expected source indices to be of index type"); // 2. Destination memref. - if (!getDstMemRef().getType().isa()) + if (!llvm::isa(getDstMemRef().getType())) return emitOpError("expected destination to be of memref type"); unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4; if (numOperands < numExpectedOperands) @@ -1283,7 +1284,7 @@ return emitOpError("expected num elements to be of index type"); // 4. Tag memref. - if (!getTagMemRef().getType().isa()) + if (!llvm::isa(getTagMemRef().getType())) return emitOpError("expected tag to be of memref type"); numExpectedOperands += getTagMemRefRank(); if (numOperands < numExpectedOperands) @@ -1359,7 +1360,8 @@ SmallVectorImpl &inferredReturnTypes) { ExtractStridedMetadataOpAdaptor extractAdaptor(operands, attributes, properties); - auto sourceType = extractAdaptor.getSource().getType().dyn_cast(); + auto sourceType = + llvm::dyn_cast(extractAdaptor.getSource().getType()); if (!sourceType) return failure(); @@ -1409,8 +1411,7 @@ "The constified value should be either unchanged (i.e., == result) " "or a constant"); Value constantVal = rewriter.create( - loc, maybeConstant.template get() - .template cast() + loc, llvm::cast(maybeConstant.template get()) .getInt()); for (Operation *op : llvm::make_early_inc_range(result.getUsers())) { // updateRootInplace: lambda cannot capture structured bindings in C++17 @@ -1470,7 +1471,7 @@ result.addOperands(memref); result.addOperands(ivs); - if (auto memrefType = memref.getType().dyn_cast()) { + if (auto memrefType = llvm::dyn_cast(memref.getType())) { Type elementType = memrefType.getElementType(); result.addTypes(elementType); @@ -1519,7 +1520,7 @@ if (parser.parseRegion(*body, {}) || parser.parseOptionalAttrDict(result.attributes)) return failure(); - result.types.push_back(memrefType.cast().getElementType()); + result.types.push_back(llvm::cast(memrefType).getElementType()); return success(); } @@ -1567,7 +1568,7 @@ if (parser.parseType(type)) return failure(); - auto memrefType = type.dyn_cast(); + auto memrefType = llvm::dyn_cast(type); if (!memrefType || !memrefType.hasStaticShape()) return parser.emitError(parser.getNameLoc()) << "type should be static shaped memref, but got " << type; @@ -1584,14 +1585,14 @@ Type tensorType = getTensorTypeFromMemRefType(memrefType); if (parser.parseAttribute(initialValue, tensorType)) return failure(); - if (!initialValue.isa()) + if (!llvm::isa(initialValue)) return parser.emitError(parser.getNameLoc()) << "initial value should be a unit or elements attribute"; return success(); } LogicalResult GlobalOp::verify() { - auto memrefType = getType().dyn_cast(); + auto memrefType = llvm::dyn_cast(getType()); if (!memrefType || !memrefType.hasStaticShape()) return emitOpError("type should be static shaped memref, but got ") << getType(); @@ -1600,14 +1601,14 @@ // an elements attribute. if (getInitialValue().has_value()) { Attribute initValue = getInitialValue().value(); - if (!initValue.isa() && !initValue.isa()) + if (!llvm::isa(initValue) && !llvm::isa(initValue)) return emitOpError("initial value should be a unit or elements " "attribute, but got ") << initValue; // Check that the type of the initial value is compatible with the type of // the global variable. - if (auto elementsAttr = initValue.dyn_cast()) { + if (auto elementsAttr = llvm::dyn_cast(initValue)) { Type initType = elementsAttr.getType(); Type tensorType = getTensorTypeFromMemRefType(memrefType); if (initType != tensorType) @@ -1631,7 +1632,7 @@ ElementsAttr GlobalOp::getConstantInitValue() { auto initVal = getInitialValue(); if (getConstant() && initVal.has_value()) - return initVal.value().cast(); + return llvm::cast(initVal.value()); return {}; } @@ -1687,11 +1688,11 @@ if (inputs.size() != 1 || outputs.size() != 1) return false; Type a = inputs.front(), b = outputs.front(); - auto aT = a.dyn_cast(); - auto bT = b.dyn_cast(); + auto aT = llvm::dyn_cast(a); + auto bT = llvm::dyn_cast(b); - auto uaT = a.dyn_cast(); - auto ubT = b.dyn_cast(); + auto uaT = llvm::dyn_cast(a); + auto ubT = llvm::dyn_cast(b); if (aT && bT) { if (aT.getElementType() != bT.getElementType()) @@ -1794,7 +1795,7 @@ OpFoldResult RankOp::fold(FoldAdaptor adaptor) { // Constant fold rank when the rank of the operand is known. auto type = getOperand().getType(); - auto shapedType = type.dyn_cast(); + auto shapedType = llvm::dyn_cast(type); if (shapedType && shapedType.hasRank()) return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank()); return IntegerAttr(); @@ -1861,8 +1862,8 @@ // completed automatically, like we have for subview and extract_slice. LogicalResult ReinterpretCastOp::verify() { // The source and result memrefs should be in the same memory space. - auto srcType = getSource().getType().cast(); - auto resultType = getType().cast(); + auto srcType = llvm::cast(getSource().getType()); + auto resultType = llvm::cast(getType()); if (srcType.getMemorySpace() != resultType.getMemorySpace()) return emitError("different memory spaces specified for source type ") << srcType << " and result memref type " << resultType; @@ -2250,7 +2251,7 @@ ArrayRef resultShape, Value src, ArrayRef reassociation) { // Only ranked memref source values are supported. - auto srcType = src.getType().cast(); + auto srcType = llvm::cast(src.getType()); FailureOr resultType = ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation); // Failure of this assertion usually indicates a problem with the source @@ -2406,7 +2407,7 @@ void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src, ArrayRef reassociation, ArrayRef attrs) { - auto srcType = src.getType().cast(); + auto srcType = llvm::cast(src.getType()); MemRefType resultType = CollapseShapeOp::computeCollapsedType(srcType, reassociation); build(b, result, resultType, src, attrs); @@ -2473,7 +2474,7 @@ return failure(); Type newResultType = CollapseShapeOp::computeCollapsedType( - cast.getOperand().getType().cast(), + llvm::cast(cast.getOperand().getType()), op.getReassociationIndices()); if (newResultType == op.getResultType()) { @@ -2518,18 +2519,20 @@ Type operandType = getSource().getType(); Type resultType = getResult().getType(); - Type operandElementType = operandType.cast().getElementType(); - Type resultElementType = resultType.cast().getElementType(); + Type operandElementType = + llvm::cast(operandType).getElementType(); + Type resultElementType = llvm::cast(resultType).getElementType(); if (operandElementType != resultElementType) return emitOpError("element types of source and destination memref " "types should be the same"); - if (auto operandMemRefType = operandType.dyn_cast()) + if (auto operandMemRefType = llvm::dyn_cast(operandType)) if (!operandMemRefType.getLayout().isIdentity()) return emitOpError("source memref type should have identity affine map"); - int64_t shapeSize = getShape().getType().cast().getDimSize(0); - auto resultMemRefType = resultType.dyn_cast(); + int64_t shapeSize = + llvm::cast(getShape().getType()).getDimSize(0); + auto resultMemRefType = llvm::dyn_cast(resultType); if (resultMemRefType) { if (!resultMemRefType.getLayout().isIdentity()) return emitOpError("result memref type should have identity affine map"); @@ -2634,9 +2637,8 @@ ArrayRef offsets, ArrayRef sizes, ArrayRef strides) { - auto inferredType = - inferResultType(sourceRankedTensorType, offsets, sizes, strides) - .cast(); + auto inferredType = llvm::cast( + inferResultType(sourceRankedTensorType, offsets, sizes, strides)); assert(inferredType.getRank() >= static_cast(resultShape.size()) && "expected "); if (inferredType.getRank() == static_cast(resultShape.size())) @@ -2648,7 +2650,7 @@ assert(dimsToProject.has_value() && "invalid rank reduction"); // Compute the layout and result type. - auto inferredLayout = inferredType.getLayout().cast(); + auto inferredLayout = llvm::cast(inferredType.getLayout()); SmallVector rankReducedStrides; rankReducedStrides.reserve(resultShape.size()); for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) { @@ -2690,12 +2692,11 @@ dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes); dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); - auto sourceMemRefType = source.getType().cast(); + auto sourceMemRefType = llvm::cast(source.getType()); // Structuring implementation this way avoids duplication between builders. if (!resultType) { - resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets, - staticSizes, staticStrides) - .cast(); + resultType = llvm::cast(SubViewOp::inferResultType( + sourceMemRefType, staticOffsets, staticSizes, staticStrides)); } build(b, result, resultType, source, dynamicOffsets, dynamicSizes, dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets), @@ -2824,7 +2825,7 @@ template static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, OpTy op, Type expectedType) { - auto memrefType = expectedType.cast(); + auto memrefType = llvm::cast(expectedType); switch (result) { case SliceVerificationResult::Success: return success(); @@ -2867,7 +2868,7 @@ auto expectedType = SubViewOp::inferResultType( baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides()); - auto result = isRankReducedMemRefType(expectedType.cast(), + auto result = isRankReducedMemRefType(llvm::cast(expectedType), subViewType, getMixedSizes()); return produceSubViewErrorMsg(result, *this, expectedType); } @@ -2917,9 +2918,8 @@ MemRefType currentResultType, MemRefType currentSourceType, MemRefType sourceType, ArrayRef mixedOffsets, ArrayRef mixedSizes, ArrayRef mixedStrides) { - auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets, - mixedSizes, mixedStrides) - .cast(); + auto nonRankReducedType = llvm::cast(SubViewOp::inferResultType( + sourceType, mixedOffsets, mixedSizes, mixedStrides)); std::optional unusedDims = computeMemRefRankReductionMask(currentSourceType, currentResultType, mixedSizes); @@ -2927,7 +2927,7 @@ if (!unusedDims) return nullptr; - auto layout = nonRankReducedType.getLayout().cast(); + auto layout = llvm::cast(nonRankReducedType.getLayout()); SmallVector shape, strides; unsigned numDimsAfterReduction = nonRankReducedType.getRank() - unusedDims->count(); @@ -2962,14 +2962,14 @@ Value mlir::memref::createCanonicalRankReducingSubViewOp( OpBuilder &b, Location loc, Value memref, ArrayRef targetShape) { - auto memrefType = memref.getType().cast(); + auto memrefType = llvm::cast(memref.getType()); unsigned rank = memrefType.getRank(); SmallVector offsets(rank, b.getIndexAttr(0)); SmallVector sizes = getMixedSizes(b, loc, memref); SmallVector strides(rank, b.getIndexAttr(1)); - auto targetType = SubViewOp::inferRankReducedResultType( - targetShape, memrefType, offsets, sizes, strides) - .cast(); + auto targetType = + llvm::cast(SubViewOp::inferRankReducedResultType( + targetShape, memrefType, offsets, sizes, strides)); return b.createOrFold(loc, targetType, memref, offsets, sizes, strides); } @@ -2977,7 +2977,7 @@ FailureOr SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value, ArrayRef desiredShape) { - auto sourceMemrefType = value.getType().dyn_cast(); + auto sourceMemrefType = llvm::dyn_cast(value.getType()); assert(sourceMemrefType && "not a ranked memref type"); auto sourceShape = sourceMemrefType.getShape(); if (sourceShape.equals(desiredShape)) @@ -3069,7 +3069,7 @@ // if the operation is rank-reducing. auto resultType = getCanonicalSubViewResultType( subViewOp.getType(), subViewOp.getSourceType(), - castOp.getSource().getType().cast(), + llvm::cast(castOp.getSource().getType()), subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(), subViewOp.getMixedStrides()); if (!resultType) @@ -3134,8 +3134,8 @@ } OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) { - auto resultShapedType = getResult().getType().cast(); - auto sourceShapedType = getSource().getType().cast(); + auto resultShapedType = llvm::cast(getResult().getType()); + auto sourceShapedType = llvm::cast(getSource().getType()); if (resultShapedType.hasStaticShape() && resultShapedType == sourceShapedType) { @@ -3201,7 +3201,7 @@ auto permutationMap = permutation.getValue(); assert(permutationMap); - auto memRefType = in.getType().cast(); + auto memRefType = llvm::cast(in.getType()); // Compute result type. MemRefType resultType = inferTransposeResultType(memRefType, permutationMap); @@ -3239,8 +3239,8 @@ if (getPermutation().getNumDims() != getIn().getType().getRank()) return emitOpError("expected a permutation map of same rank as the input"); - auto srcType = getIn().getType().cast(); - auto dstType = getType().cast(); + auto srcType = llvm::cast(getIn().getType()); + auto dstType = llvm::cast(getType()); auto transposedType = inferTransposeResultType(srcType, getPermutation()); if (dstType != transposedType) return emitOpError("output type ") @@ -3264,7 +3264,7 @@ } LogicalResult ViewOp::verify() { - auto baseType = getOperand(0).getType().cast(); + auto baseType = llvm::cast(getOperand(0).getType()); auto viewType = getType(); // The base memref should have identity layout map (or none). @@ -3401,7 +3401,7 @@ case arith::AtomicRMWKind::maxf: case arith::AtomicRMWKind::minf: case arith::AtomicRMWKind::mulf: - if (!getValue().getType().isa()) + if (!llvm::isa(getValue().getType())) return emitOpError() << "with kind '" << arith::stringifyAtomicRMWKind(getKind()) << "' expects a floating-point type"; @@ -3414,7 +3414,7 @@ case arith::AtomicRMWKind::muli: case arith::AtomicRMWKind::ori: case arith::AtomicRMWKind::andi: - if (!getValue().getType().isa()) + if (!llvm::isa(getValue().getType())) return emitOpError() << "with kind '" << arith::stringifyAtomicRMWKind(getKind()) << "' expects an integer type"; diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp @@ -37,8 +37,8 @@ auto castOp = cast(op); assert(value == castOp.getResult() && "invalid value"); - if (castOp.getResult().getType().isa() && - castOp.getSource().getType().isa()) { + if (llvm::isa(castOp.getResult().getType()) && + llvm::isa(castOp.getSource().getType())) { cstr.bound(value)[dim] == cstr.getExpr(castOp.getSource(), dim); } } @@ -79,7 +79,7 @@ auto rankOp = cast(op); assert(value == rankOp.getResult() && "invalid value"); - auto memrefType = rankOp.getMemref().getType().dyn_cast(); + auto memrefType = llvm::dyn_cast(rankOp.getMemref().getType()); if (!memrefType) return; cstr.bound(value) == memrefType.getRank(); diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp --- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp +++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp @@ -69,7 +69,7 @@ results.push_back(*newBuffer); } - transformResults.set(getResult().cast(), results); + transformResults.set(llvm::cast(getResult()), results); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp @@ -57,7 +57,7 @@ // always 1. if (llvm::all_of(strides, [](OpFoldResult &valueOrAttr) { Attribute attr = valueOrAttr.dyn_cast(); - return attr && attr.cast().getInt() == 1; + return attr && llvm::cast(attr).getInt() == 1; })) { strides = SmallVector(sourceOp.getMixedStrides().size(), rewriter.getI64IntegerAttr(1)); @@ -93,8 +93,8 @@ // If both offsets are static we can simply calculate the combined // offset statically. offsets.push_back(rewriter.getI64IntegerAttr( - opOffsetAttr.cast().getInt() + - sourceOffsetAttr.cast().getInt())); + llvm::cast(opOffsetAttr).getInt() + + llvm::cast(sourceOffsetAttr).getInt())); } else { // When either offset is dynamic, we must emit an additional affine // transformation to add the two offsets together dynamically. @@ -102,7 +102,7 @@ SmallVector affineApplyOperands; for (auto valueOrAttr : {opOffset, sourceOffset}) { if (auto attr = valueOrAttr.dyn_cast()) { - expr = expr + attr.cast().getInt(); + expr = expr + llvm::cast(attr).getInt(); } else { expr = expr + rewriter.getAffineSymbolExpr(affineApplyOperands.size()); diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp @@ -149,7 +149,7 @@ arith::WideIntEmulationConverter &typeConverter) { typeConverter.addConversion( [&typeConverter](MemRefType ty) -> std::optional { - auto intTy = ty.getElementType().dyn_cast(); + auto intTy = llvm::dyn_cast(ty.getElementType()); if (!intTy) return ty; diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp @@ -89,11 +89,11 @@ LogicalResult matchAndRewrite(memref::ReshapeOp op, PatternRewriter &rewriter) const final { - auto shapeType = op.getShape().getType().cast(); + auto shapeType = llvm::cast(op.getShape().getType()); if (!shapeType.hasStaticShape()) return failure(); - int64_t rank = shapeType.cast().getDimSize(0); + int64_t rank = llvm::cast(shapeType).getDimSize(0); SmallVector sizes, strides; sizes.resize(rank); strides.resize(rank); @@ -106,7 +106,7 @@ if (op.getType().isDynamicDim(i)) { Value index = rewriter.create(loc, i); size = rewriter.create(loc, op.getShape(), index); - if (!size.getType().isa()) + if (!llvm::isa(size.getType())) size = rewriter.create( loc, rewriter.getIndexType(), size); sizes[i] = size; @@ -141,7 +141,7 @@ op.getKind() != arith::AtomicRMWKind::minf; }); target.addDynamicallyLegalOp([](memref::ReshapeOp op) { - return !op.getShape().getType().cast().hasStaticShape(); + return !llvm::cast(op.getShape().getType()).hasStaticShape(); }); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -62,7 +62,7 @@ // Build a plain extract_strided_metadata(memref) from subview(memref). Location origLoc = subview.getLoc(); Value source = subview.getSource(); - auto sourceType = source.getType().cast(); + auto sourceType = llvm::cast(source.getType()); unsigned sourceRank = sourceType.getRank(); auto newExtractStridedMetadata = @@ -115,7 +115,7 @@ // The final result is . // Thus we need 1 + 1 + subview.getRank() + subview.getRank(), to hold all // the values. - auto subType = subview.getType().cast(); + auto subType = llvm::cast(subview.getType()); unsigned subRank = subType.getRank(); // The sizes of the final type are defined directly by the input sizes of @@ -338,7 +338,7 @@ // Collect the statically known information about the original stride. Value source = expandShape.getSrc(); - auto sourceType = source.getType().cast(); + auto sourceType = llvm::cast(source.getType()); auto [strides, offset] = getStridesAndOffset(sourceType); OpFoldResult origStride = ShapedType::isDynamic(strides[groupId]) @@ -358,10 +358,10 @@ AffineExpr s0 = builder.getAffineSymbolExpr(0); AffineExpr s1 = builder.getAffineSymbolExpr(1); for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) { - int64_t baseExpandedStride = expandedStrides[doneStrideIdx] - .get() - .cast() - .getInt(); + int64_t baseExpandedStride = + llvm::cast( + expandedStrides[doneStrideIdx].get()) + .getInt(); expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply( builder, expandShape.getLoc(), (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1, @@ -372,10 +372,9 @@ // Now apply the origStride to the remaining dimensions. AffineExpr s0 = builder.getAffineSymbolExpr(0); for (; doneStrideIdx < groupSize; ++doneStrideIdx) { - int64_t baseExpandedStride = expandedStrides[doneStrideIdx] - .get() - .cast() - .getInt(); + int64_t baseExpandedStride = + llvm::cast(expandedStrides[doneStrideIdx].get()) + .getInt(); expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply( builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride}); } @@ -445,7 +444,7 @@ // Build the affine expr of the product of the original sizes involved in that // group. Value source = collapseShape.getSrc(); - auto sourceType = source.getType().cast(); + auto sourceType = llvm::cast(source.getType()); SmallVector reassocGroup = collapseShape.getReassociationIndices()[groupId]; @@ -479,7 +478,7 @@ "Reassociation group should have at least one dimension"); Value source = collapseShape.getSrc(); - auto sourceType = source.getType().cast(); + auto sourceType = llvm::cast(source.getType()); auto [strides, offset] = getStridesAndOffset(sourceType); @@ -562,7 +561,7 @@ // extract_strided_metadata(reassociative_reshape_like(memref)). Location origLoc = reshape.getLoc(); Value source = reshape.getSrc(); - auto sourceType = source.getType().cast(); + auto sourceType = llvm::cast(source.getType()); unsigned sourceRank = sourceType.getRank(); auto newExtractStridedMetadata = @@ -650,8 +649,7 @@ if (!allocLikeOp) return failure(); - auto memRefType = - allocLikeOp.getResult().getType().template cast(); + auto memRefType = llvm::cast(allocLikeOp.getResult().getType()); if (!memRefType.getLayout().isIdentity()) return rewriter.notifyMatchFailure( allocLikeOp, "alloc-like operations should have been normalized"); @@ -688,7 +686,7 @@ SmallVector results; results.reserve(rank * 2 + 2); - auto baseBufferType = op.getBaseBuffer().getType().cast(); + auto baseBufferType = llvm::cast(op.getBaseBuffer().getType()); int64_t offset = 0; if (allocLikeOp.getType() == baseBufferType) results.push_back(allocLikeOp); @@ -737,7 +735,7 @@ if (!getGlobalOp) return failure(); - auto memRefType = getGlobalOp.getResult().getType().cast(); + auto memRefType = llvm::cast(getGlobalOp.getResult().getType()); if (!memRefType.getLayout().isIdentity()) { return rewriter.notifyMatchFailure( getGlobalOp, @@ -759,7 +757,7 @@ SmallVector results; results.reserve(rank * 2 + 2); - auto baseBufferType = op.getBaseBuffer().getType().cast(); + auto baseBufferType = llvm::cast(op.getBaseBuffer().getType()); int64_t offset = 0; if (getGlobalOp.getType() == baseBufferType) results.push_back(getGlobalOp); @@ -839,7 +837,7 @@ reinterpretCastOp, "reinterpret_cast source's type is incompatible"); auto memrefType = - reinterpretCastOp.getResult().getType().cast(); + llvm::cast(reinterpretCastOp.getResult().getType()); unsigned rank = memrefType.getRank(); SmallVector results; results.resize_for_overwrite(rank * 2 + 2); diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp @@ -120,7 +120,7 @@ static FailureOr getTransferLikeOpSrcMemRef(TransferLikeOp transferLikeOp) { Value src = transferLikeOp.getSource(); - if (src.getType().isa()) + if (llvm::isa(src.getType())) return src; return failure(); } @@ -240,7 +240,7 @@ return rewriter.notifyMatchFailure(loadStoreLikeOp, "source is not a memref"); Value srcMemRef = *failureOrSrcMemRef; - auto ldStTy = srcMemRef.getType().cast(); + auto ldStTy = llvm::cast(srcMemRef.getType()); unsigned loadStoreRank = ldStTy.getRank(); // Don't waste compile time if there is nothing to rewrite. if (loadStoreRank == 0) diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp @@ -148,7 +148,8 @@ if (collapseShapeOp.getReassociationIndices().empty()) { auto zeroAffineMap = rewriter.getConstantAffineMap(0); int64_t srcRank = - collapseShapeOp.getViewSource().getType().cast().getRank(); + llvm::cast(collapseShapeOp.getViewSource().getType()) + .getRank(); for (int64_t i = 0; i < srcRank; i++) { OpFoldResult ofr = affine::makeComposedFoldedAffineApply( rewriter, loc, zeroAffineMap, dynamicIndices); diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp @@ -72,10 +72,9 @@ OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(op); auto newResultType = - SubViewOp::inferRankReducedResultType( + llvm::cast(SubViewOp::inferRankReducedResultType( op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(), - op.getMixedSizes(), op.getMixedStrides()) - .cast(); + op.getMixedSizes(), op.getMixedStrides())); Value newSubview = rewriter.create( op.getLoc(), newResultType, conversionOp.getOperand(0), op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides()); diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp @@ -61,11 +61,11 @@ OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(subviewUse); Type newType = memref::SubViewOp::inferRankReducedResultType( - subviewUse.getType().getShape(), val.getType().cast(), + subviewUse.getType().getShape(), llvm::cast(val.getType()), subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(), subviewUse.getStaticStrides()); Value newSubview = rewriter.create( - subviewUse->getLoc(), newType.cast(), val, + subviewUse->getLoc(), llvm::cast(newType), val, subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(), subviewUse.getMixedStrides()); @@ -209,9 +209,9 @@ for (int64_t i = 0, e = originalShape.size(); i != e; ++i) sizes[1 + i] = rewriter.getIndexAttr(originalShape[i]); // Strides is [1, 1 ... 1 ]. - auto dstMemref = memref::SubViewOp::inferRankReducedResultType( - originalShape, mbMemRefType, offsets, sizes, strides) - .cast(); + auto dstMemref = + llvm::cast(memref::SubViewOp::inferRankReducedResultType( + originalShape, mbMemRefType, offsets, sizes, strides)); Value subview = rewriter.create(loc, dstMemref, mbAlloc, offsets, sizes, strides); LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n"); diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp @@ -180,7 +180,7 @@ llvm::seq(0, callOp.getNumResults())) { Value oldMemRef = callOp.getResult(resIndex); if (auto oldMemRefType = - oldMemRef.getType().dyn_cast()) + llvm::dyn_cast(oldMemRef.getType())) if (!oldMemRefType.getLayout().isIdentity() && !isMemRefNormalizable(oldMemRef.getUsers())) return WalkResult::interrupt(); @@ -192,7 +192,7 @@ for (unsigned argIndex : llvm::seq(0, funcOp.getNumArguments())) { BlockArgument oldMemRef = funcOp.getArgument(argIndex); - if (auto oldMemRefType = oldMemRef.getType().dyn_cast()) + if (auto oldMemRefType = llvm::dyn_cast(oldMemRef.getType())) if (!oldMemRefType.getLayout().isIdentity() && !isMemRefNormalizable(oldMemRef.getUsers())) return false; @@ -226,7 +226,7 @@ funcOp.walk([&](func::ReturnOp returnOp) { for (const auto &operandEn : llvm::enumerate(returnOp.getOperands())) { Type opType = operandEn.value().getType(); - MemRefType memrefType = opType.dyn_cast(); + MemRefType memrefType = llvm::dyn_cast(opType); // If type is not memref or if the memref type is same as that in // function's return signature then no update is required. if (!memrefType || memrefType == resultTypes[operandEn.index()]) @@ -283,8 +283,9 @@ // need not perform any use replacement here. if (oldResult.getType() == newResult.getType()) continue; - AffineMap layoutMap = - oldResult.getType().cast().getLayout().getAffineMap(); + AffineMap layoutMap = llvm::cast(oldResult.getType()) + .getLayout() + .getAffineMap(); if (failed(replaceAllMemRefUsesWith(oldResult, /*newMemRef=*/newResult, /*extraIndices=*/{}, /*indexRemap=*/layoutMap, @@ -358,7 +359,7 @@ for (unsigned argIndex : llvm::seq(0, functionType.getNumInputs())) { Type argType = functionType.getInput(argIndex); - MemRefType memrefType = argType.dyn_cast(); + MemRefType memrefType = llvm::dyn_cast(argType); // Check whether argument is of MemRef type. Any other argument type can // simply be part of the final function signature. if (!memrefType) { @@ -422,11 +423,13 @@ // Replace all uses of the old memrefs. Value oldMemRef = op->getResult(resIndex); Value newMemRef = newOp->getResult(resIndex); - MemRefType oldMemRefType = oldMemRef.getType().dyn_cast(); + MemRefType oldMemRefType = + llvm::dyn_cast(oldMemRef.getType()); // Check whether the operation result is MemRef type. if (!oldMemRefType) continue; - MemRefType newMemRefType = newMemRef.getType().cast(); + MemRefType newMemRefType = + llvm::cast(newMemRef.getType()); if (oldMemRefType == newMemRefType) continue; // TODO: Assume single layout map. Multiple maps not supported. @@ -466,7 +469,7 @@ for (unsigned resIndex : llvm::seq(0, functionType.getNumResults())) { Type resType = functionType.getResult(resIndex); - MemRefType memrefType = resType.dyn_cast(); + MemRefType memrefType = llvm::dyn_cast(resType); // Check whether result is of MemRef type. Any other argument type can // simply be part of the final function signature. if (!memrefType) { @@ -507,7 +510,7 @@ bool resultTypeNormalized = false; for (unsigned resIndex : llvm::seq(0, oldOp->getNumResults())) { auto resultType = oldOp->getResult(resIndex).getType(); - MemRefType memrefType = resultType.dyn_cast(); + MemRefType memrefType = llvm::dyn_cast(resultType); // Check whether the operation result is MemRef type. if (!memrefType) { resultTypes.push_back(resultType); diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp @@ -40,7 +40,7 @@ LogicalResult matchAndRewrite(OpTy dimOp, PatternRewriter &rewriter) const override { - OpResult dimValue = dimOp.getSource().template dyn_cast(); + OpResult dimValue = llvm::dyn_cast(dimOp.getSource()); if (!dimValue) return failure(); auto shapedTypeOp = @@ -61,8 +61,10 @@ return failure(); Value resultShape = reifiedResultShapes[dimValue.getResultNumber()]; - auto resultShapeType = resultShape.getType().dyn_cast(); - if (!resultShapeType || !resultShapeType.getElementType().isa()) + auto resultShapeType = + llvm::dyn_cast(resultShape.getType()); + if (!resultShapeType || + !llvm::isa(resultShapeType.getElementType())) return failure(); Location loc = dimOp->getLoc(); @@ -82,7 +84,7 @@ LogicalResult matchAndRewrite(OpTy dimOp, PatternRewriter &rewriter) const override { - OpResult dimValue = dimOp.getSource().template dyn_cast(); + OpResult dimValue = llvm::dyn_cast(dimOp.getSource()); if (!dimValue) return failure(); std::optional dimIndex = dimOp.getConstantIndex(); diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp @@ -38,14 +38,14 @@ void generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc) const { auto castOp = cast(op); - auto srcType = castOp.getSource().getType().cast(); + auto srcType = llvm::cast(castOp.getSource().getType()); // Nothing to check if the result is an unranked memref. - auto resultType = castOp.getType().dyn_cast(); + auto resultType = llvm::dyn_cast(castOp.getType()); if (!resultType) return; - if (srcType.isa()) { + if (llvm::isa(srcType)) { // Check rank. Value srcRank = builder.create(loc, castOp.getSource()); Value resultRank = @@ -75,7 +75,7 @@ // Check dimension sizes. for (const auto &it : llvm::enumerate(resultType.getShape())) { // Static dim size -> static/dynamic dim size does not need verification. - if (auto rankedSrcType = srcType.dyn_cast()) + if (auto rankedSrcType = llvm::dyn_cast(srcType)) if (!rankedSrcType.isDynamicDim(it.index())) continue; diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp --- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp +++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp @@ -38,9 +38,9 @@ Attribute memorySpace = type.getMemorySpace(); if (!memorySpace) return false; - if (auto intAttr = memorySpace.dyn_cast()) + if (auto intAttr = llvm::dyn_cast(memorySpace)) return intAttr.getInt() == NVGPUDialect::kSharedMemoryAddressSpace; - if (auto gpuAttr = memorySpace.dyn_cast()) + if (auto gpuAttr = llvm::dyn_cast(memorySpace)) return gpuAttr.getValue() == gpu::AddressSpace::Workgroup; return false; } @@ -61,8 +61,8 @@ } LogicalResult DeviceAsyncCopyOp::verify() { - auto srcMemref = getSrc().getType().cast(); - auto dstMemref = getDst().getType().cast(); + auto srcMemref = llvm::cast(getSrc().getType()); + auto dstMemref = llvm::cast(getDst().getType()); if (!isLastMemrefDimUnitStride(srcMemref)) return emitError("source memref most minor dim must have unit stride"); @@ -246,10 +246,10 @@ LogicalResult LdMatrixOp::verify() { // ldmatrix reads data from source in shared memory - auto srcMemref = getSrcMemref().getType().cast(); + auto srcMemref = llvm::cast(getSrcMemref().getType()); // ldmatrix writes data to result/destination in vector registers - auto resVector = getRes().getType().cast(); + auto resVector = llvm::cast(getRes().getType()); // vector register shape, element type, and bitwidth ArrayRef resShape = resVector.getShape(); diff --git a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp --- a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp +++ b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp @@ -42,7 +42,9 @@ Location location = op->getLoc(); if (op->hasAttr(op.getTf32EnabledAttrName()) || - !op.getMatrixA().getType().cast().getElementType().isF32()) + !llvm::cast(op.getMatrixA().getType()) + .getElementType() + .isF32()) return failure(); if (precision == MmaSyncF32Lowering::Unkown) diff --git a/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp --- a/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp +++ b/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp @@ -180,7 +180,7 @@ mlir::LogicalResult mlir::nvgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp, Value memrefValue) { - auto memRefType = memrefValue.getType().dyn_cast(); + auto memRefType = llvm::dyn_cast(memrefValue.getType()); if (!memRefType || !NVGPUDialect::hasSharedMemoryAddressSpace(memRefType)) return failure(); diff --git a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp --- a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp +++ b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp @@ -63,7 +63,7 @@ info.vectorType = writeOp.getVectorType(); } else if (isa(op)) { - info.vectorType = op->getResult(0).getType().cast(); + info.vectorType = llvm::cast(op->getResult(0).getType()); } else { return op->emitError() << "unhandled operation type in nvgpu.mma.sync conversion path"; diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -41,7 +41,7 @@ struct PointerLikeModel : public PointerLikeType::ExternalModel, T> { Type getElementType(Type pointer) const { - return pointer.cast().getElementType(); + return llvm::cast(pointer).getElementType(); } }; @@ -231,7 +231,7 @@ // Check if all alignment values are positive - OpenMP 4.5 -> 2.8.1 section for (unsigned i = 0; i < (*alignmentValues).size(); ++i) { - if (auto intAttr = (*alignmentValues)[i].dyn_cast()) { + if (auto intAttr = llvm::dyn_cast((*alignmentValues)[i])) { if (intAttr.getValue().sle(0)) return op->emitOpError() << "alignment should be greater than 0"; } else { @@ -463,7 +463,7 @@ return op->emitOpError() << "accumulator variable used more than once"; Type varType = accum.getType(); - auto symbolRef = std::get<1>(args).cast(); + auto symbolRef = llvm::cast(std::get<1>(args)); auto decl = SymbolTable::lookupNearestSymbolFrom(op, symbolRef); if (!decl) @@ -521,7 +521,8 @@ if (i != 0) p << ", "; p << stringifyClauseTaskDepend( - (*depends)[i].cast().getValue()) + llvm::cast((*depends)[i]) + .getValue()) << " -> " << dependVars[i] << " : " << dependTypes[i]; } } @@ -723,8 +724,8 @@ Value mapOp = map_operands[i]; Attribute mapTypeOp = map_types[i]; - assert(mapTypeOp.isa()); - mapTypeBits = mapTypeOp.cast().getInt(); + assert(llvm::isa(mapTypeOp)); + mapTypeBits = llvm::cast(mapTypeOp).getInt(); bool always = bitAnd(mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS); @@ -1018,8 +1019,8 @@ atomicReductionEntryBlock.getArgumentTypes()[1]) return emitOpError() << "expects atomic reduction region with two " "arguments of the same type"; - auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0] - .dyn_cast(); + auto ptrType = llvm::dyn_cast( + atomicReductionEntryBlock.getArgumentTypes()[0]); if (!ptrType || (ptrType.getElementType() && ptrType.getElementType() != getType())) return emitOpError() << "expects atomic reduction region arguments to " @@ -1210,7 +1211,7 @@ } } Type elementType = - getAddress().getType().cast().getElementType(); + llvm::cast(getAddress().getType()).getElementType(); if (elementType && elementType != getValue().getType()) return emitError("address must dereference to value type"); return verifySynchronizationHint(*this, getHintVal()); @@ -1261,7 +1262,8 @@ if (getRegion().getNumArguments() != 1) return emitError("the region must accept exactly one argument"); - Type elementType = getX().getType().cast().getElementType(); + Type elementType = + llvm::cast(getX().getType()).getElementType(); if (elementType && elementType != getRegion().getArgument(0).getType()) { return emitError("the type of the operand must be a pointer type whose " "element type is the same as that of the region argument"); diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp --- a/mlir/lib/Dialect/PDL/IR/PDL.cpp +++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp @@ -465,7 +465,7 @@ } LogicalResult ResultsOp::verify() { - if (!getIndex() && getType().isa()) { + if (!getIndex() && llvm::isa(getType())) { return emitOpError() << "expected `pdl.range` result type when " "no index is specified, but got: " << getType(); diff --git a/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp b/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp --- a/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp +++ b/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp @@ -60,7 +60,7 @@ } Type pdl::getRangeElementTypeOrSelf(Type type) { - if (auto rangeType = type.dyn_cast()) + if (auto rangeType = llvm::dyn_cast(type)) return rangeType.getElementType(); return type; } @@ -78,7 +78,7 @@ if (!elementType || parser.parseGreater()) return Type(); - if (elementType.isa()) { + if (llvm::isa(elementType)) { parser.emitError(elementLoc) << "element of pdl.range cannot be another range, but got" << elementType; @@ -95,7 +95,7 @@ LogicalResult RangeType::verify(function_ref emitError, Type elementType) { - if (!elementType.isa() || elementType.isa()) { + if (!llvm::isa(elementType) || llvm::isa(elementType)) { return emitError() << "expected element of pdl.range to be one of [!pdl.attribute, " "!pdl.operation, !pdl.type, !pdl.value], but got " diff --git a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp --- a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp +++ b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp @@ -145,7 +145,7 @@ if (initLoop) { // Create the block and the loop variable. // FIXME: Allow passing in a proper location for the loop variable. - auto rangeType = range.getType().cast(); + auto rangeType = llvm::cast(range.getType()); state.regions.front()->emplaceBlock(); state.regions.front()->addArgument(rangeType.getElementType(), state.location); @@ -238,7 +238,8 @@ /// Given the result type of a `GetValueTypeOp`, return the expected input type. static Type getGetValueTypeOpValueType(Type type) { Type valueTy = pdl::ValueType::get(type.getContext()); - return type.isa() ? pdl::RangeType::get(valueTy) : valueTy; + return llvm::isa(type) ? pdl::RangeType::get(valueTy) + : valueTy; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp --- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp @@ -35,7 +35,7 @@ // Verify that the storage type is integral. // This restriction may be lifted at some point in favor of using bf16 // or f16 as exact representations on hardware where that is advantageous. - auto intStorageType = storageType.dyn_cast(); + auto intStorageType = llvm::dyn_cast(storageType); if (!intStorageType) return emitError() << "storage type must be integral"; unsigned integralWidth = intStorageType.getWidth(); @@ -83,8 +83,8 @@ } bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) { - if (candidateExpressedType.isa()) { - return candidateExpressedType.cast().getElementType() == + if (llvm::isa(candidateExpressedType)) { + return llvm::cast(candidateExpressedType).getElementType() == getExpressedType(); } return candidateExpressedType == getExpressedType(); @@ -92,12 +92,12 @@ QuantizedType QuantizedType::getQuantizedElementType(Type primitiveOrContainerType) { - if (primitiveOrContainerType.isa()) { + if (llvm::isa(primitiveOrContainerType)) { Type elementType = - primitiveOrContainerType.cast().getElementType(); - return elementType.dyn_cast(); + llvm::cast(primitiveOrContainerType).getElementType(); + return llvm::dyn_cast(elementType); } - return primitiveOrContainerType.dyn_cast(); + return llvm::dyn_cast(primitiveOrContainerType); } Type QuantizedType::castFromStorageType(Type candidateType) { @@ -105,18 +105,19 @@ // i.e. i32 -> quant<"uniform[i8:f32]{1.0}"> return *this; } - if (candidateType.isa()) { + if (llvm::isa(candidateType)) { // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> return RankedTensorType::get( - candidateType.cast().getShape(), getStorageType()); + llvm::cast(candidateType).getShape(), + getStorageType()); } - if (candidateType.isa()) { + if (llvm::isa(candidateType)) { // i.e. tensor -> tensor> return UnrankedTensorType::get(getStorageType()); } - if (candidateType.isa()) { + if (llvm::isa(candidateType)) { // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> - return VectorType::get(candidateType.cast().getShape(), + return VectorType::get(llvm::cast(candidateType).getShape(), getStorageType()); } @@ -124,25 +125,25 @@ } Type QuantizedType::castToStorageType(Type quantizedType) { - if (quantizedType.isa()) { + if (llvm::isa(quantizedType)) { // i.e. quant<"uniform[i8:f32]{1.0}"> -> i8 - return quantizedType.cast().getStorageType(); + return llvm::cast(quantizedType).getStorageType(); } - if (quantizedType.isa()) { + if (llvm::isa(quantizedType)) { // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> - ShapedType sType = quantizedType.cast(); - if (!sType.getElementType().isa()) { + ShapedType sType = llvm::cast(quantizedType); + if (!llvm::isa(sType.getElementType())) { return nullptr; } Type storageType = - sType.getElementType().cast().getStorageType(); - if (quantizedType.isa()) { + llvm::cast(sType.getElementType()).getStorageType(); + if (llvm::isa(quantizedType)) { return RankedTensorType::get(sType.getShape(), storageType); } - if (quantizedType.isa()) { + if (llvm::isa(quantizedType)) { return UnrankedTensorType::get(storageType); } - if (quantizedType.isa()) { + if (llvm::isa(quantizedType)) { return VectorType::get(sType.getShape(), storageType); } } @@ -155,21 +156,21 @@ // i.e. f32 -> quant<"uniform[i8:f32]{1.0}"> return *this; } - if (candidateType.isa()) { - ShapedType candidateShapedType = candidateType.cast(); + if (llvm::isa(candidateType)) { + ShapedType candidateShapedType = llvm::cast(candidateType); if (candidateShapedType.getElementType() != getExpressedType()) { return nullptr; } - if (candidateType.isa()) { + if (llvm::isa(candidateType)) { // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> return RankedTensorType::get(candidateShapedType.getShape(), *this); } - if (candidateType.isa()) { + if (llvm::isa(candidateType)) { // i.e. tensor -> tensor> return UnrankedTensorType::get(*this); } - if (candidateType.isa()) { + if (llvm::isa(candidateType)) { // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> return VectorType::get(candidateShapedType.getShape(), *this); } @@ -179,25 +180,25 @@ } Type QuantizedType::castToExpressedType(Type quantizedType) { - if (quantizedType.isa()) { + if (llvm::isa(quantizedType)) { // i.e. quant<"uniform[i8:f32]{1.0}"> -> f32 - return quantizedType.cast().getExpressedType(); + return llvm::cast(quantizedType).getExpressedType(); } - if (quantizedType.isa()) { + if (llvm::isa(quantizedType)) { // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> - ShapedType sType = quantizedType.cast(); - if (!sType.getElementType().isa()) { + ShapedType sType = llvm::cast(quantizedType); + if (!llvm::isa(sType.getElementType())) { return nullptr; } Type expressedType = - sType.getElementType().cast().getExpressedType(); - if (quantizedType.isa()) { + llvm::cast(sType.getElementType()).getExpressedType(); + if (llvm::isa(quantizedType)) { return RankedTensorType::get(sType.getShape(), expressedType); } - if (quantizedType.isa()) { + if (llvm::isa(quantizedType)) { return UnrankedTensorType::get(expressedType); } - if (quantizedType.isa()) { + if (llvm::isa(quantizedType)) { return VectorType::get(sType.getShape(), expressedType); } } @@ -243,7 +244,7 @@ // Verify that the expressed type is floating point. // If this restriction is ever eliminated, the parser/printer must be // extended. - if (expressedType && !expressedType.isa()) + if (expressedType && !llvm::isa(expressedType)) return emitError() << "expressed type must be floating point"; return success(); @@ -284,7 +285,7 @@ // Verify that the expressed type is floating point. // If this restriction is ever eliminated, the parser/printer must be // extended. - if (!expressedType.isa()) + if (!llvm::isa(expressedType)) return emitError() << "expressed type must be floating point"; // Verify scale. @@ -338,7 +339,7 @@ // Verify that the expressed type is floating point. // If this restriction is ever eliminated, the parser/printer must be // extended. - if (!expressedType.isa()) + if (!llvm::isa(expressedType)) return emitError() << "expressed type must be floating point"; // Ensure that the number of scales and zeroPoints match. @@ -385,7 +386,7 @@ // Verify that the expressed type is floating point. // If this restriction is ever eliminated, the parser/printer must be // extended. - if (!expressedType.isa()) + if (!llvm::isa(expressedType)) return emitError() << "expressed type must be floating point"; if (max <= min) return emitError() << "illegal min and max: (" << min << ":" << max << ")"; diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp --- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp +++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp @@ -420,13 +420,13 @@ /// Print a type registered to this dialect. void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const { - if (auto anyType = type.dyn_cast()) + if (auto anyType = llvm::dyn_cast(type)) printAnyQuantizedType(anyType, os); - else if (auto uniformType = type.dyn_cast()) + else if (auto uniformType = llvm::dyn_cast(type)) printUniformQuantizedType(uniformType, os); - else if (auto perAxisType = type.dyn_cast()) + else if (auto perAxisType = llvm::dyn_cast(type)) printUniformQuantizedPerAxisType(perAxisType, os); - else if (auto calibratedType = type.dyn_cast()) + else if (auto calibratedType = llvm::dyn_cast(type)) printCalibratedQuantizedType(calibratedType, os); else llvm_unreachable("Unhandled quantized type"); diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -495,7 +495,7 @@ Region &ForOp::getLoopBody() { return getRegion(); } ForOp mlir::scf::getForInductionVarOwner(Value val) { - auto ivArg = val.dyn_cast(); + auto ivArg = llvm::dyn_cast(val); if (!ivArg) return ForOp(); assert(ivArg.getOwner() && "unlinked block argument"); @@ -576,7 +576,7 @@ }; Value srcVal = mapping.lookupOrDefault(src); - if (srcVal.getType().isa()) { + if (llvm::isa(srcVal.getType())) { results.push_back(rewriter.create( forallOp.getLoc(), dst.getType(), srcVal, mapping.lookupOrDefault(dst), @@ -890,7 +890,8 @@ replaceTensorCastForOpIterArg(PatternRewriter &rewriter, OpOperand &operand, Value replacement) { Type oldType = operand.get().getType(), newType = replacement.getType(); - assert(oldType.isa() && newType.isa() && + assert(llvm::isa(oldType) && + llvm::isa(newType) && "expected ranked tensor types"); // 1. Create new iter operands, exactly 1 is replaced. @@ -1074,7 +1075,7 @@ cast(forOp.getRegion().front().getTerminator()); Value yieldVal = yieldOp->getOperand(idx); auto tensorLoadOp = yieldVal.getDefiningOp(); - bool isTensor = bbArg.getType().isa(); + bool isTensor = llvm::isa(bbArg.getType()); bufferization::ToMemrefOp tensorToMemref; // Either bbArg has no use or it has a single buffer_cast use. @@ -1445,7 +1446,7 @@ } ForallOp mlir::scf::getForallOpThreadIndexOwner(Value val) { - auto tidxArg = val.dyn_cast(); + auto tidxArg = llvm::dyn_cast(val); if (!tidxArg) return ForallOp(); assert(tidxArg.getOwner() && "unlinked block argument"); @@ -1464,7 +1465,8 @@ if (!forallOp) return failure(); Value sharedOut = - forallOp.getTiedOpOperand(dimOp.getSource().cast())->get(); + forallOp.getTiedOpOperand(llvm::cast(dimOp.getSource())) + ->get(); rewriter.updateRootInPlace( dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); }); return success(); @@ -1744,7 +1746,7 @@ llvm::map_range(getYieldingOps(), [](Operation &op) { // Add new ops here as needed. auto insertSliceOp = cast(&op); - return insertSliceOp.getDest().cast(); + return llvm::cast(insertSliceOp.getDest()); })); } @@ -1964,7 +1966,7 @@ // Otherwise, the successor is dependent on the condition. bool condition; - if (auto condAttr = operands.front().dyn_cast_or_null()) { + if (auto condAttr = llvm::dyn_cast_or_null(operands.front())) { condition = condAttr.getValue().isOne(); } else { // If the condition isn't constant, both regions may be executed. @@ -2006,7 +2008,7 @@ void IfOp::getRegionInvocationBounds( ArrayRef operands, SmallVectorImpl &invocationBounds) { - if (auto cond = operands[0].dyn_cast_or_null()) { + if (auto cond = llvm::dyn_cast_or_null(operands[0])) { // If the condition is known, then one region is known to be executed once // and the other zero times. invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0); @@ -2542,7 +2544,7 @@ // come from the same scf.if. for (const auto &tup : llvm::enumerate(thenYield)) { if (tup.value().getDefiningOp() == nestedIf) { - auto nestedIdx = tup.value().cast().getResultNumber(); + auto nestedIdx = llvm::cast(tup.value()).getResultNumber(); if (nestedIf.elseYield().getOperand(nestedIdx) != elseYield[tup.index()]) { return failure(); @@ -2818,7 +2820,7 @@ Region &ParallelOp::getLoopBody() { return getRegion(); } ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) { - auto ivArg = val.dyn_cast(); + auto ivArg = llvm::dyn_cast(val); if (!ivArg) return ParallelOp(); assert(ivArg.getOwner() && "unlinked block argument"); @@ -3130,7 +3132,7 @@ // Try to narrow the successor to the condition region. assert(!operands.empty() && "expected at least one operand"); - auto cond = operands[0].dyn_cast_or_null(); + auto cond = llvm::dyn_cast_or_null(operands[0]); if (!cond || !cond.getValue()) regions.emplace_back(getResults()); if (!cond || cond.getValue()) @@ -3360,7 +3362,7 @@ // block argument or the initial value of i-th before block argument. If // the comparison results `true`, i-th before block argument is a loop // invariant. - auto yieldOpBlockArg = yieldOpArg.dyn_cast(); + auto yieldOpBlockArg = llvm::dyn_cast(yieldOpArg); if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) { Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()]; if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) { @@ -3392,7 +3394,7 @@ // before block argument or the initial value of i-th before block // argument. If the comparison results `true`, i-th before block // argument is a loop invariant. - auto yieldOpBlockArg = yieldOpArg.dyn_cast(); + auto yieldOpBlockArg = llvm::dyn_cast(yieldOpArg); if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) { Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()]; if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) { @@ -3960,7 +3962,7 @@ } // If a constant was not provided, all regions are possible successors. - auto operandValue = operands.front().dyn_cast_or_null(); + auto operandValue = llvm::dyn_cast_or_null(operands.front()); if (!operandValue) { for (Region &caseRegion : getCaseRegions()) successors.emplace_back(&caseRegion); @@ -3981,7 +3983,7 @@ void IndexSwitchOp::getRegionInvocationBounds( ArrayRef operands, SmallVectorImpl &bounds) { - auto operandValue = operands.front().dyn_cast_or_null(); + auto operandValue = llvm::dyn_cast_or_null(operands.front()); if (!operandValue) { // All regions are invoked at most once. bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1)); diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp @@ -27,10 +27,10 @@ ValueBoundsConstraintSet &cstr) { // `value` is an iter_arg or an OpResult. int64_t iterArgIdx; - if (auto iterArg = value.dyn_cast()) { + if (auto iterArg = llvm::dyn_cast(value)) { iterArgIdx = iterArg.getArgNumber() - forOp.getNumInductionVars(); } else { - iterArgIdx = value.cast().getResultNumber(); + iterArgIdx = llvm::cast(value).getResultNumber(); } // An EQ constraint can be added if the yielded value (dimension size) @@ -63,7 +63,7 @@ bound, boundOperands, BoundType::EQ, yieldedValue, dim, [&](Value v, std::optional d) { // Stop when reaching a block argument of the loop body. - if (auto bbArg = v.dyn_cast()) + if (auto bbArg = llvm::dyn_cast(v)) return bbArg.getOwner()->getParentOp() == forOp; // Stop when reaching a value that is defined outside of the loop. It // is impossible to reach an iter_arg from there. diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -30,8 +30,9 @@ /// Helper function for loop bufferization. Cast the given buffer to the given /// memref type. static Value castBuffer(OpBuilder &b, Value buffer, Type type) { - assert(type.isa() && "expected BaseMemRefType"); - assert(buffer.getType().isa() && "expected BaseMemRefType"); + assert(llvm::isa(type) && "expected BaseMemRefType"); + assert(llvm::isa(buffer.getType()) && + "expected BaseMemRefType"); // If the buffer already has the correct type, no cast is needed. if (buffer.getType() == type) return buffer; @@ -78,7 +79,7 @@ SmallVector newArgs; for (const auto &it : llvm::enumerate(conditionOp.getArgs())) { Value value = it.value(); - if (value.getType().isa()) { + if (llvm::isa(value.getType())) { FailureOr maybeBuffer = getBuffer(rewriter, value, options); if (failed(maybeBuffer)) return failure(); @@ -141,7 +142,7 @@ rewriter.setInsertionPointAfter(newOp); SmallVector newResults; for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) { - if (it.value().isa()) { + if (llvm::isa(it.value())) { newResults.push_back(rewriter.create( executeRegionOp.getLoc(), newOp->getResult(it.index()))); } else { @@ -183,7 +184,7 @@ // Compute bufferized result types. SmallVector newTypes; for (Value result : ifOp.getResults()) { - if (!result.getType().isa()) { + if (!llvm::isa(result.getType())) { newTypes.push_back(result.getType()); continue; } @@ -218,13 +219,13 @@ assert(value.getDefiningOp() == op && "invalid valid"); // Determine buffer types of the true/false branches. - auto opResult = value.cast(); + auto opResult = llvm::cast(value); auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber()); auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber()); BaseMemRefType thenBufferType, elseBufferType; - if (thenValue.getType().isa()) { + if (llvm::isa(thenValue.getType())) { // True branch was already bufferized. - thenBufferType = thenValue.getType().cast(); + thenBufferType = llvm::cast(thenValue.getType()); } else { auto maybeBufferType = bufferization::getBufferType(thenValue, options, fixedTypes); @@ -232,9 +233,9 @@ return failure(); thenBufferType = *maybeBufferType; } - if (elseValue.getType().isa()) { + if (llvm::isa(elseValue.getType())) { // False branch was already bufferized. - elseBufferType = elseValue.getType().cast(); + elseBufferType = llvm::cast(elseValue.getType()); } else { auto maybeBufferType = bufferization::getBufferType(elseValue, options, fixedTypes); @@ -253,7 +254,8 @@ // Layout maps are different: Promote to fully dynamic layout map. return getMemRefTypeWithFullyDynamicLayout( - opResult.getType().cast(), thenBufferType.getMemorySpace()); + llvm::cast(opResult.getType()), + thenBufferType.getMemorySpace()); } }; @@ -262,7 +264,7 @@ static DenseSet getTensorIndices(ValueRange values) { DenseSet result; for (const auto &it : llvm::enumerate(values)) - if (it.value().getType().isa()) + if (llvm::isa(it.value().getType())) result.insert(it.index()); return result; } @@ -275,8 +277,8 @@ unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size()); DenseSet result; for (unsigned int i = 0; i < minSize; ++i) { - if (!bbArgs[i].getType().isa() || - !yieldedValues[i].getType().isa()) + if (!llvm::isa(bbArgs[i].getType()) || + !llvm::isa(yieldedValues[i].getType())) continue; if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i])) result.insert(i); @@ -291,7 +293,7 @@ const BufferizationOptions &options) { SmallVector result; for (OpOperand &opOperand : operands) { - if (opOperand.get().getType().isa()) { + if (llvm::isa(opOperand.get().getType())) { FailureOr resultBuffer = getBuffer(rewriter, opOperand.get(), options); if (failed(resultBuffer)) @@ -361,9 +363,9 @@ // Compute the buffer type of the yielded value. BaseMemRefType yieldedValueBufferType; - if (yieldedValue.getType().isa()) { + if (llvm::isa(yieldedValue.getType())) { // scf.yield was already bufferized. - yieldedValueBufferType = yieldedValue.getType().cast(); + yieldedValueBufferType = llvm::cast(yieldedValue.getType()); } else { auto maybeBufferType = bufferization::getBufferType(yieldedValue, options, newFixedTypes); @@ -379,7 +381,7 @@ // If there is a mismatch between the yielded buffer type and the iter_arg // buffer type, the buffer type must be promoted to a fully dynamic layout // map. - auto yieldedRanked = yieldedValueBufferType.cast(); + auto yieldedRanked = llvm::cast(yieldedValueBufferType); #ifndef NDEBUG auto iterRanked = initArgBufferType->cast(); assert(llvm::equal(yieldedRanked.getShape(), iterRanked.getShape()) && @@ -388,7 +390,7 @@ "expected same memory space"); #endif // NDEBUG return getMemRefTypeWithFullyDynamicLayout( - iterArg.getType().cast(), + llvm::cast(iterArg.getType()), yieldedRanked.getMemorySpace()); } @@ -516,16 +518,16 @@ const DenseMap &fixedTypes) const { auto forOp = cast(op); assert(getOwnerOfValue(value) == op && "invalid value"); - assert(value.getType().isa() && "expected tensor type"); + assert(llvm::isa(value.getType()) && "expected tensor type"); // Get result/argument number. unsigned resultNum; - if (auto bbArg = value.dyn_cast()) { + if (auto bbArg = llvm::dyn_cast(value)) { resultNum = forOp.getResultForOpOperand(forOp.getOpOperandForRegionIterArg(bbArg)) .getResultNumber(); } else { - resultNum = value.cast().getResultNumber(); + resultNum = llvm::cast(value).getResultNumber(); } // Compute the bufferized type. @@ -560,7 +562,7 @@ Value initArg = it.value(); Value result = forOp->getResult(it.index()); // If the type is not a tensor, bufferization doesn't need to touch it. - if (!result.getType().isa()) { + if (!llvm::isa(result.getType())) { castedInitArgs.push_back(initArg); continue; } @@ -611,7 +613,7 @@ auto yieldOp = cast(forOp.getLoopBody().front().getTerminator()); for (OpResult opResult : op->getOpResults()) { - if (!opResult.getType().isa()) + if (!llvm::isa(opResult.getType())) continue; // Note: This is overly strict. We should check for aliasing bufferized @@ -736,7 +738,7 @@ for (int64_t idx = 0; idx < static_cast(conditionOp.getArgs().size()); ++idx) { Value value = conditionOp.getArgs()[idx]; - if (!value.getType().isa() || + if (!llvm::isa(value.getType()) || (equivalentYieldsAfter.contains(idx) && equivalentYieldsBefore.contains(idx))) { beforeYieldValues.push_back(value); @@ -786,7 +788,7 @@ Value initArg = it.value(); Value beforeArg = whileOp.getBeforeArguments()[it.index()]; // If the type is not a tensor, bufferization doesn't need to touch it. - if (!beforeArg.getType().isa()) { + if (!llvm::isa(beforeArg.getType())) { castedInitArgs.push_back(initArg); continue; } @@ -799,7 +801,7 @@ // The result types of a WhileOp are the same as the "after" bbArg types. SmallVector argsTypesAfter = llvm::to_vector( llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) { - if (!bbArg.getType().isa()) + if (!llvm::isa(bbArg.getType())) return bbArg.getType(); // TODO: error handling return bufferization::getBufferType(bbArg, options)->cast(); @@ -848,10 +850,10 @@ const DenseMap &fixedTypes) const { auto whileOp = cast(op); assert(getOwnerOfValue(value) == op && "invalid value"); - assert(value.getType().isa() && "expected tensor type"); + assert(llvm::isa(value.getType()) && "expected tensor type"); // Case 1: Block argument of the "before" region. - if (auto bbArg = value.dyn_cast()) { + if (auto bbArg = llvm::dyn_cast(value)) { if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) { Value initArg = whileOp.getInits()[bbArg.getArgNumber()]; auto yieldOp = whileOp.getYieldOp(); @@ -865,18 +867,18 @@ // The bufferized "after" bbArg type can be directly computed from the // bufferized "before" bbArg type. unsigned resultNum; - if (auto opResult = value.dyn_cast()) { + if (auto opResult = llvm::dyn_cast(value)) { resultNum = opResult.getResultNumber(); - } else if (value.cast().getOwner()->getParent() == + } else if (llvm::cast(value).getOwner()->getParent() == &whileOp.getAfter()) { - resultNum = value.cast().getArgNumber(); + resultNum = llvm::cast(value).getArgNumber(); } else { llvm_unreachable("invalid value"); } Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum]; - if (!conditionYieldedVal.getType().isa()) { + if (!llvm::isa(conditionYieldedVal.getType())) { // scf.condition was already bufferized. - return conditionYieldedVal.getType().cast(); + return llvm::cast(conditionYieldedVal.getType()); } return bufferization::getBufferType(conditionYieldedVal, options, fixedTypes); @@ -902,7 +904,7 @@ auto conditionOp = whileOp.getConditionOp(); for (const auto &it : llvm::enumerate(conditionOp.getArgs())) { - if (!it.value().getType().isa()) + if (!llvm::isa(it.value().getType())) continue; if (!state.areEquivalentBufferizedValues( it.value(), conditionOp->getBlock()->getArgument(it.index()))) @@ -913,7 +915,7 @@ auto yieldOp = whileOp.getYieldOp(); for (const auto &it : llvm::enumerate(yieldOp.getResults())) { - if (!it.value().getType().isa()) + if (!llvm::isa(it.value().getType())) continue; if (!state.areEquivalentBufferizedValues( it.value(), yieldOp->getBlock()->getArgument(it.index()))) @@ -971,7 +973,7 @@ SmallVector newResults; for (const auto &it : llvm::enumerate(yieldOp.getResults())) { Value value = it.value(); - if (value.getType().isa()) { + if (llvm::isa(value.getType())) { FailureOr maybeBuffer = getBuffer(rewriter, value, options); if (failed(maybeBuffer)) return failure(); @@ -1110,7 +1112,7 @@ const DenseMap &fixedTypes) const { auto forallOp = cast(op); - if (auto bbArg = value.dyn_cast()) + if (auto bbArg = llvm::dyn_cast(value)) // A tensor block argument has the same bufferized type as the // corresponding output operand. return bufferization::getBufferType( @@ -1119,7 +1121,7 @@ // The bufferized result type is the same as the bufferized type of the // corresponding output operand. return bufferization::getBufferType( - forallOp.getOutputs()[value.cast().getResultNumber()], + forallOp.getOutputs()[llvm::cast(value).getResultNumber()], options, fixedTypes); } diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp --- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp @@ -43,7 +43,7 @@ while (value) { if (value == forOp.getRegionIterArgs()[arg]) return true; - OpResult opResult = value.dyn_cast(); + OpResult opResult = llvm::dyn_cast(value); if (!opResult) return false; @@ -91,7 +91,7 @@ LogicalResult matchAndRewrite(OpTy dimOp, PatternRewriter &rewriter) const override { - auto blockArg = dimOp.getSource().template dyn_cast(); + auto blockArg = llvm::dyn_cast(dimOp.getSource()); if (!blockArg) return failure(); auto forOp = dyn_cast(blockArg.getParentBlock()->getParentOp()); @@ -139,7 +139,7 @@ auto forOp = dimOp.getSource().template getDefiningOp(); if (!forOp) return failure(); - auto opResult = dimOp.getSource().template cast(); + auto opResult = llvm::cast(dimOp.getSource()); unsigned resultNumber = opResult.getResultNumber(); if (!isShapePreserving(forOp, resultNumber)) return failure(); diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp @@ -145,18 +145,20 @@ } uint32_t spirv::InterfaceVarABIAttr::getBinding() { - return getImpl()->binding.cast().getInt(); + return llvm::cast(getImpl()->binding).getInt(); } uint32_t spirv::InterfaceVarABIAttr::getDescriptorSet() { - return getImpl()->descriptorSet.cast().getInt(); + return llvm::cast(getImpl()->descriptorSet).getInt(); } std::optional spirv::InterfaceVarABIAttr::getStorageClass() { if (getImpl()->storageClass) return static_cast( - getImpl()->storageClass.cast().getValue().getZExtValue()); + llvm::cast(getImpl()->storageClass) + .getValue() + .getZExtValue()); return std::nullopt; } @@ -170,7 +172,7 @@ return emitError() << "expected 32-bit integer for binding"; if (storageClass) { - if (auto storageClassAttr = storageClass.cast()) { + if (auto storageClassAttr = llvm::cast(storageClass)) { auto storageClassValue = spirv::symbolizeStorageClass(storageClassAttr.getInt()); if (!storageClassValue) @@ -219,14 +221,14 @@ spirv::Version spirv::VerCapExtAttr::getVersion() { return static_cast( - getImpl()->version.cast().getValue().getZExtValue()); + llvm::cast(getImpl()->version).getValue().getZExtValue()); } spirv::VerCapExtAttr::ext_iterator::ext_iterator(ArrayAttr::iterator it) : llvm::mapped_iterator( it, [](Attribute attr) { - return *symbolizeExtension(attr.cast().getValue()); + return *symbolizeExtension(llvm::cast(attr).getValue()); }) {} spirv::VerCapExtAttr::ext_range spirv::VerCapExtAttr::getExtensions() { @@ -235,7 +237,7 @@ } ArrayAttr spirv::VerCapExtAttr::getExtensionsAttr() { - return getImpl()->extensions.cast(); + return llvm::cast(getImpl()->extensions); } spirv::VerCapExtAttr::cap_iterator::cap_iterator(ArrayAttr::iterator it) @@ -243,7 +245,7 @@ spirv::Capability (*)(Attribute)>( it, [](Attribute attr) { return *symbolizeCapability( - attr.cast().getValue().getZExtValue()); + llvm::cast(attr).getValue().getZExtValue()); }) {} spirv::VerCapExtAttr::cap_range spirv::VerCapExtAttr::getCapabilities() { @@ -252,7 +254,7 @@ } ArrayAttr spirv::VerCapExtAttr::getCapabilitiesAttr() { - return getImpl()->capabilities.cast(); + return llvm::cast(getImpl()->capabilities); } LogicalResult @@ -263,7 +265,7 @@ return emitError() << "expected 32-bit integer for version"; if (!llvm::all_of(capabilities.getValue(), [](Attribute attr) { - if (auto intAttr = attr.dyn_cast()) + if (auto intAttr = llvm::dyn_cast(attr)) if (spirv::symbolizeCapability(intAttr.getValue().getZExtValue())) return true; return false; @@ -271,7 +273,7 @@ return emitError() << "unknown capability in capability list"; if (!llvm::all_of(extensions.getValue(), [](Attribute attr) { - if (auto strAttr = attr.dyn_cast()) + if (auto strAttr = llvm::dyn_cast(attr)) if (spirv::symbolizeExtension(strAttr.getValue())) return true; return false; @@ -297,7 +299,7 @@ StringRef spirv::TargetEnvAttr::getKindName() { return "target_env"; } spirv::VerCapExtAttr spirv::TargetEnvAttr::getTripleAttr() const { - return getImpl()->triple.cast(); + return llvm::cast(getImpl()->triple); } spirv::Version spirv::TargetEnvAttr::getVersion() const { @@ -337,7 +339,7 @@ } spirv::ResourceLimitsAttr spirv::TargetEnvAttr::getResourceLimits() const { - return getImpl()->limits.cast(); + return llvm::cast(getImpl()->limits); } //===----------------------------------------------------------------------===// @@ -628,7 +630,7 @@ [&](spirv::Capability cap) { os << spirv::stringifyCapability(cap); }); printer << "], ["; llvm::interleaveComma(triple.getExtensionsAttr(), os, [&](Attribute attr) { - os << attr.cast().getValue(); + os << llvm::cast(attr).getValue(); }); printer << "]>"; } @@ -669,11 +671,11 @@ if (succeeded(generatedAttributePrinter(attr, printer))) return; - if (auto targetEnv = attr.dyn_cast()) + if (auto targetEnv = llvm::dyn_cast(attr)) print(targetEnv, printer); - else if (auto vceAttr = attr.dyn_cast()) + else if (auto vceAttr = llvm::dyn_cast(attr)) print(vceAttr, printer); - else if (auto interfaceVarABIAttr = attr.dyn_cast()) + else if (auto interfaceVarABIAttr = llvm::dyn_cast(attr)) print(interfaceVarABIAttr, printer); else llvm_unreachable("unhandled SPIR-V attribute kind"); diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp @@ -33,9 +33,9 @@ if (!attr) return std::nullopt; - if (auto boolAttr = attr.dyn_cast()) + if (auto boolAttr = llvm::dyn_cast(attr)) return boolAttr.getValue(); - if (auto splatAttr = attr.dyn_cast()) + if (auto splatAttr = llvm::dyn_cast(attr)) if (splatAttr.getElementType().isInteger(1)) return splatAttr.getSplatValue(); return std::nullopt; @@ -52,12 +52,12 @@ if (indices.empty()) return composite; - if (auto vector = composite.dyn_cast()) { + if (auto vector = llvm::dyn_cast(composite)) { assert(indices.size() == 1 && "must have exactly one index for a vector"); return vector.getValues()[indices[0]]; } - if (auto array = composite.dyn_cast()) { + if (auto array = llvm::dyn_cast(composite)) { assert(!indices.empty() && "must have at least one index for an array"); return extractCompositeElement(array.getValue()[indices[0]], indices.drop_front()); @@ -149,7 +149,7 @@ if (auto constructOp = getComposite().getDefiningOp()) { - auto type = constructOp.getType().cast(); + auto type = llvm::cast(constructOp.getType()); if (getIndices().size() == 1 && constructOp.getConstituents().size() == type.getNumElements()) { auto i = getIndices().begin()->cast(); @@ -159,7 +159,7 @@ auto indexVector = llvm::to_vector<8>(llvm::map_range(getIndices(), [](Attribute attr) { - return static_cast(attr.cast().getInt()); + return static_cast(llvm::cast(attr).getInt()); })); return extractCompositeElement(adaptor.getComposite(), indexVector); } @@ -434,10 +434,9 @@ // "Before version 1.4, Result Type must be a pointer, scalar, or vector. // Starting with version 1.4, Result Type can additionally be a composite type // other than a vector." - bool isScalarOrVector = trueBrStoreOp.getValue() - .getType() - .cast() - .isScalarOrVector(); + bool isScalarOrVector = + llvm::cast(trueBrStoreOp.getValue().getType()) + .isScalarOrVector(); // Check that each `spirv.Store` uses the same pointer, memory access // attributes and a valid type of the value. diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -164,19 +164,19 @@ return type; // Check other allowed types - if (auto t = type.dyn_cast()) { + if (auto t = llvm::dyn_cast(type)) { if (type.isBF16()) { parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types"); return Type(); } - } else if (auto t = type.dyn_cast()) { + } else if (auto t = llvm::dyn_cast(type)) { if (!ScalarType::isValid(t)) { parser.emitError(typeLoc, "only 1/8/16/32/64-bit integer type allowed but found ") << type; return Type(); } - } else if (auto t = type.dyn_cast()) { + } else if (auto t = llvm::dyn_cast(type)) { if (t.getRank() != 1) { parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t; return Type(); @@ -203,7 +203,7 @@ if (parser.parseType(type)) return Type(); - if (auto t = type.dyn_cast()) { + if (auto t = llvm::dyn_cast(type)) { if (t.getRank() != 1) { parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t; return Type(); @@ -216,7 +216,7 @@ return Type(); } - if (!t.getElementType().isa()) { + if (!llvm::isa(t.getElementType())) { parser.emitError(typeLoc, "matrix columns' elements must be of " "Float type, got ") << t.getElementType(); @@ -239,7 +239,7 @@ if (parser.parseType(type)) return Type(); - if (!type.isa()) { + if (!llvm::isa(type)) { parser.emitError(typeLoc, "sampled image must be composed using image type, got ") << type; @@ -939,12 +939,12 @@ Attribute attr = attribute.getValue(); if (symbol == spirv::getEntryPointABIAttrName()) { - if (!attr.isa()) { + if (!llvm::isa(attr)) { return op->emitError("'") << symbol << "' attribute must be an entry point ABI attribute"; } } else if (symbol == spirv::getTargetEnvAttrName()) { - if (!attr.isa()) + if (!llvm::isa(attr)) return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr"; } else { return op->emitError("found unsupported '") @@ -965,7 +965,7 @@ return emitError(loc, "found unsupported '") << symbol << "' attribute on region argument"; - auto varABIAttr = attr.dyn_cast(); + auto varABIAttr = llvm::dyn_cast(attr); if (!varABIAttr) return emitError(loc, "'") << symbol << "' must be a spirv::InterfaceVarABIAttr"; diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -81,7 +81,7 @@ parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(type)) return failure(); - auto fnType = type.dyn_cast(); + auto fnType = llvm::dyn_cast(type); if (!fnType) { parser.emitError(loc, "expected function type"); return failure(); @@ -141,7 +141,7 @@ return failure(); } auto valueAttr = constOp.getValue(); - auto integerValueAttr = valueAttr.dyn_cast(); + auto integerValueAttr = llvm::dyn_cast(valueAttr); if (!integerValueAttr) { return failure(); } @@ -181,11 +181,11 @@ if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(), attrName, attr)) return failure(); - if (!attrVal.isa()) + if (!llvm::isa(attrVal)) return parser.emitError(loc, "expected ") << attrName << " attribute specified as string"; - auto attrOptional = - spirv::symbolizeEnum(attrVal.cast().getValue()); + auto attrOptional = spirv::symbolizeEnum( + llvm::cast(attrVal).getValue()); if (!attrOptional) return parser.emitError(loc, "invalid ") << attrName << " attribute specification: " << attrVal; @@ -430,23 +430,23 @@ Type resultType = op->getResult(0).getType(); // ODS checks that result type and operand type have the same shape. - if (auto vectorType = operandType.dyn_cast()) { + if (auto vectorType = llvm::dyn_cast(operandType)) { operandType = vectorType.getElementType(); - resultType = resultType.cast().getElementType(); + resultType = llvm::cast(resultType).getElementType(); } if (auto coopMatrixType = - operandType.dyn_cast()) { + llvm::dyn_cast(operandType)) { operandType = coopMatrixType.getElementType(); resultType = - resultType.cast().getElementType(); + llvm::cast(resultType).getElementType(); } if (auto jointMatrixType = - operandType.dyn_cast()) { + llvm::dyn_cast(operandType)) { operandType = jointMatrixType.getElementType(); resultType = - resultType.cast().getElementType(); + llvm::cast(resultType).getElementType(); } auto operandTypeBitWidth = operandType.getIntOrFloatBitWidth(); @@ -490,7 +490,7 @@ return success(); } - auto memAccess = memAccessAttr.template cast(); + auto memAccess = llvm::cast(memAccessAttr); if (!memAccess) { return memoryOp.emitOpError("invalid memory access specifier: ") @@ -534,7 +534,7 @@ return success(); } - auto memAccess = memAccessAttr.template cast(); + auto memAccess = llvm::cast(memAccessAttr); if (!memAccess) { return memoryOp.emitOpError("invalid memory access specifier: ") @@ -589,7 +589,7 @@ // TODO: Check that the value type satisfies restrictions of // SPIR-V OpLoad/OpStore operations if (val.getType() != - ptr.getType().cast().getPointeeType()) { + llvm::cast(ptr.getType()).getPointeeType()) { return op.emitOpError("mismatch in result type and pointer type"); } return success(); @@ -599,10 +599,11 @@ static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val) { auto valType = val.getType(); - if (auto valVecTy = valType.dyn_cast()) + if (auto valVecTy = llvm::dyn_cast(valType)) valType = valVecTy.getElementType(); - if (valType != ptr.getType().cast().getPointeeType()) { + if (valType != + llvm::cast(ptr.getType()).getPointeeType()) { return op.emitOpError("mismatch in result type and pointer type"); } return success(); @@ -674,7 +675,7 @@ // Get bit width of types. static unsigned getBitWidth(Type type) { - if (type.isa()) { + if (llvm::isa(type)) { // Just return 64 bits for pointer types for now. // TODO: Make sure not caller relies on the actual pointer width value. return 64; @@ -683,7 +684,7 @@ if (type.isIntOrFloat()) return type.getIntOrFloatBitWidth(); - if (auto vectorType = type.dyn_cast()) { + if (auto vectorType = llvm::dyn_cast(type)) { assert(vectorType.getElementType().isIntOrFloat()); return vectorType.getNumElements() * vectorType.getElementType().getIntOrFloatBitWidth(); @@ -703,7 +704,7 @@ } for (auto index : indices) { - if (auto cType = type.dyn_cast()) { + if (auto cType = llvm::dyn_cast(type)) { if (cType.hasCompileTimeKnownNumElements() && (index < 0 || static_cast(index) >= cType.getNumElements())) { @@ -723,7 +724,7 @@ static Type getElementType(Type type, Attribute indices, function_ref emitErrorFn) { - auto indicesArrayAttr = indices.dyn_cast(); + auto indicesArrayAttr = llvm::dyn_cast(indices); if (!indicesArrayAttr) { emitErrorFn("expected a 32-bit integer array attribute for 'indices'"); return nullptr; @@ -735,7 +736,7 @@ SmallVector indexVals; for (auto indexAttr : indicesArrayAttr) { - auto indexIntAttr = indexAttr.dyn_cast(); + auto indexIntAttr = llvm::dyn_cast(indexAttr); if (!indexIntAttr) { emitErrorFn("expected an 32-bit integer for index, but found '") << indexAttr << "'"; @@ -769,7 +770,7 @@ template static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) { - auto resultType = op.getType().template cast(); + auto resultType = llvm::cast(op.getType()); if (resultType.getNumElements() != 2) return op.emitOpError("expected result struct type containing two members"); @@ -794,7 +795,7 @@ if (parser.parseType(resultType)) return failure(); - auto structType = resultType.dyn_cast(); + auto structType = llvm::dyn_cast(resultType); if (!structType || structType.getNumElements() != 2) return parser.emitError(loc, "expected spirv.struct type with two members"); @@ -836,7 +837,7 @@ parser.getCurrentLocation(&loc) || parser.parseColonType(type)) return failure(); - auto ptrType = type.dyn_cast(); + auto ptrType = llvm::dyn_cast(type); if (!ptrType) return parser.emitError(loc, "expected pointer type"); @@ -877,9 +878,9 @@ // Verifies an atomic update op. template static LogicalResult verifyAtomicUpdateOp(Operation *op) { - auto ptrType = op->getOperand(0).getType().cast(); + auto ptrType = llvm::cast(op->getOperand(0).getType()); auto elementType = ptrType.getPointeeType(); - if (!elementType.isa()) + if (!llvm::isa(elementType)) return op->emitOpError() << "pointer operand must point to an " << stringifyTypeName() << " value, found " << elementType; @@ -990,7 +991,7 @@ static Type getUnaryOpResultType(Type operandType) { Builder builder(operandType.getContext()); Type resultType = builder.getIntegerType(1); - if (auto vecType = operandType.dyn_cast()) + if (auto vecType = llvm::dyn_cast(operandType)) return VectorType::get(vecType.getNumElements(), resultType); return resultType; } @@ -1010,7 +1011,7 @@ //===----------------------------------------------------------------------===// static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) { - auto ptrType = type.dyn_cast(); + auto ptrType = llvm::dyn_cast(type); if (!ptrType) { emitError(baseLoc, "'spirv.AccessChain' op expected a pointer " "to composite type, but provided ") @@ -1023,7 +1024,7 @@ int32_t index = 0; for (auto indexSSA : indices) { - auto cType = resultType.dyn_cast(); + auto cType = llvm::dyn_cast(resultType); if (!cType) { emitError( baseLoc, @@ -1032,7 +1033,7 @@ return nullptr; } index = 0; - if (resultType.isa()) { + if (llvm::isa(resultType)) { Operation *op = indexSSA.getDefiningOp(); if (!op) { emitError(baseLoc, "'spirv.AccessChain' op index must be an " @@ -1134,7 +1135,7 @@ return failure(); auto providedResultType = - accessChainOp.getType().template dyn_cast(); + llvm::dyn_cast(accessChainOp.getType()); if (!providedResultType) return accessChainOp.emitOpError( "result type must be a pointer, but provided") @@ -1201,7 +1202,7 @@ if (parser.parseColonType(type)) return failure(); - auto ptrType = type.dyn_cast(); + auto ptrType = llvm::dyn_cast(type); if (!ptrType) return parser.emitError(loc, "expected pointer type"); @@ -1231,10 +1232,9 @@ "result, but found ") << atomOp.getComparator().getType() << " vs " << atomOp.getType(); - Type pointeeType = atomOp.getPointer() - .getType() - .template cast() - .getPointeeType(); + Type pointeeType = + llvm::cast(atomOp.getPointer().getType()) + .getPointeeType(); if (atomOp.getType() != pointeeType) return atomOp.emitOpError( "pointer operand's pointee type must have the same " @@ -1322,7 +1322,7 @@ if (parser.parseColonType(type)) return failure(); - auto ptrType = type.dyn_cast(); + auto ptrType = llvm::dyn_cast(type); if (!ptrType) return parser.emitError(loc, "expected pointer type"); @@ -1340,7 +1340,7 @@ << getValue().getType() << " vs " << getType(); Type pointeeType = - getPointer().getType().cast().getPointeeType(); + llvm::cast(getPointer().getType()).getPointeeType(); if (getType() != pointeeType) return emitOpError("pointer operand's pointee type must have the same " "as the op result type, but found ") @@ -1537,13 +1537,13 @@ if (operandType == resultType) { return emitError("result type must be different from operand type"); } - if (operandType.isa() && - !resultType.isa()) { + if (llvm::isa(operandType) && + !llvm::isa(resultType)) { return emitError( "unhandled bit cast conversion from pointer type to non-pointer type"); } - if (!operandType.isa() && - resultType.isa()) { + if (!llvm::isa(operandType) && + llvm::isa(resultType)) { return emitError( "unhandled bit cast conversion from non-pointer type to pointer type"); } @@ -1562,8 +1562,8 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::PtrCastToGenericOp::verify() { - auto operandType = getPointer().getType().cast(); - auto resultType = getResult().getType().cast(); + auto operandType = llvm::cast(getPointer().getType()); + auto resultType = llvm::cast(getResult().getType()); spirv::StorageClass operandStorage = operandType.getStorageClass(); if (operandStorage != spirv::StorageClass::Workgroup && @@ -1590,8 +1590,8 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::GenericCastToPtrOp::verify() { - auto operandType = getPointer().getType().cast(); - auto resultType = getResult().getType().cast(); + auto operandType = llvm::cast(getPointer().getType()); + auto resultType = llvm::cast(getResult().getType()); spirv::StorageClass operandStorage = operandType.getStorageClass(); if (operandStorage != spirv::StorageClass::Generic) @@ -1618,8 +1618,8 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::GenericCastToPtrExplicitOp::verify() { - auto operandType = getPointer().getType().cast(); - auto resultType = getResult().getType().cast(); + auto operandType = llvm::cast(getPointer().getType()); + auto resultType = llvm::cast(getResult().getType()); spirv::StorageClass operandStorage = operandType.getStorageClass(); if (operandStorage != spirv::StorageClass::Generic) @@ -1719,7 +1719,7 @@ if (auto weights = getBranchWeights()) { printer << " ["; llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) { - printer << a.cast().getInt(); + printer << llvm::cast(a).getInt(); }); printer << "]"; } @@ -1736,7 +1736,7 @@ return emitOpError("must have exactly two branch weights"); } if (llvm::all_of(*weights, [](Attribute attr) { - return attr.cast().getValue().isZero(); + return llvm::cast(attr).getValue().isZero(); })) return emitOpError("branch weights cannot both be zero"); } @@ -1749,10 +1749,10 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::CompositeConstructOp::verify() { - auto cType = getType().cast(); + auto cType = llvm::cast(getType()); operand_range constituents = this->getConstituents(); - if (auto coopType = cType.dyn_cast()) { + if (auto coopType = llvm::dyn_cast(cType)) { if (constituents.size() != 1) return emitOpError("has incorrect number of operands: expected ") << "1, but provided " << constituents.size(); @@ -1763,7 +1763,7 @@ return success(); } - if (auto jointType = cType.dyn_cast()) { + if (auto jointType = llvm::dyn_cast(cType)) { if (constituents.size() != 1) return emitOpError("has incorrect number of operands: expected ") << "1, but provided " << constituents.size(); @@ -1787,7 +1787,7 @@ // If not constructing a cooperative matrix type, then we must be constructing // a vector type. - auto resultType = cType.dyn_cast(); + auto resultType = llvm::dyn_cast(cType); if (!resultType) return emitOpError( "expected to return a vector or cooperative matrix when the number of " @@ -1795,14 +1795,14 @@ SmallVector sizes; for (Value component : constituents) { - if (!component.getType().isa() && + if (!llvm::isa(component.getType()) && !component.getType().isIntOrFloat()) return emitOpError("operand type mismatch: expected operand to have " "a scalar or vector type, but provided ") << component.getType(); Type elementType = component.getType(); - if (auto vectorType = component.getType().dyn_cast()) { + if (auto vectorType = llvm::dyn_cast(component.getType())) { sizes.push_back(vectorType.getNumElements()); elementType = vectorType.getElementType(); } else { @@ -1866,7 +1866,7 @@ } LogicalResult spirv::CompositeExtractOp::verify() { - auto indicesArrayAttr = getIndices().dyn_cast(); + auto indicesArrayAttr = llvm::dyn_cast(getIndices()); auto resultType = getElementType(getComposite().getType(), indicesArrayAttr, getLoc()); if (!resultType) @@ -1909,7 +1909,7 @@ } LogicalResult spirv::CompositeInsertOp::verify() { - auto indicesArrayAttr = getIndices().dyn_cast(); + auto indicesArrayAttr = llvm::dyn_cast(getIndices()); auto objectType = getElementType(getComposite().getType(), indicesArrayAttr, getLoc()); if (!objectType) @@ -1946,9 +1946,9 @@ return failure(); Type type = NoneType::get(parser.getContext()); - if (auto typedAttr = value.dyn_cast()) + if (auto typedAttr = llvm::dyn_cast(value)) type = typedAttr.getType(); - if (type.isa()) { + if (llvm::isa(type)) { if (parser.parseColonType(type)) return failure(); } @@ -1958,25 +1958,25 @@ void spirv::ConstantOp::print(OpAsmPrinter &printer) { printer << ' ' << getValue(); - if (getType().isa()) + if (llvm::isa(getType())) printer << " : " << getType(); } static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType) { - if (value.isa()) { - auto valueType = value.cast().getType(); + if (llvm::isa(value)) { + auto valueType = llvm::cast(value).getType(); if (valueType != opType) return op.emitOpError("result type (") << opType << ") does not match value type (" << valueType << ")"; return success(); } - if (value.isa()) { - auto valueType = value.cast().getType(); + if (llvm::isa(value)) { + auto valueType = llvm::cast(value).getType(); if (valueType == opType) return success(); - auto arrayType = opType.dyn_cast(); - auto shapedType = valueType.dyn_cast(); + auto arrayType = llvm::dyn_cast(opType); + auto shapedType = llvm::dyn_cast(valueType); if (!arrayType) return op.emitOpError("result or element type (") << opType << ") does not match value type (" << valueType @@ -1984,7 +1984,7 @@ int numElements = arrayType.getNumElements(); auto opElemType = arrayType.getElementType(); - while (auto t = opElemType.dyn_cast()) { + while (auto t = llvm::dyn_cast(opElemType)) { numElements *= t.getNumElements(); opElemType = t.getElementType(); } @@ -2005,8 +2005,8 @@ } return success(); } - if (auto arrayAttr = value.dyn_cast()) { - auto arrayType = opType.dyn_cast(); + if (auto arrayAttr = llvm::dyn_cast(value)) { + auto arrayType = llvm::dyn_cast(opType); if (!arrayType) return op.emitOpError( "must have spirv.array result type for array value"); @@ -2030,12 +2030,12 @@ bool spirv::ConstantOp::isBuildableWith(Type type) { // Must be valid SPIR-V type first. - if (!type.isa()) + if (!llvm::isa(type)) return false; if (isa(type.getDialect())) { // TODO: support constant struct - return type.isa(); + return llvm::isa(type); } return true; @@ -2043,7 +2043,7 @@ spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc, OpBuilder &builder) { - if (auto intType = type.dyn_cast()) { + if (auto intType = llvm::dyn_cast(type)) { unsigned width = intType.getWidth(); if (width == 1) return builder.create(loc, type, @@ -2051,19 +2051,19 @@ return builder.create( loc, type, builder.getIntegerAttr(type, APInt(width, 0))); } - if (auto floatType = type.dyn_cast()) { + if (auto floatType = llvm::dyn_cast(type)) { return builder.create( loc, type, builder.getFloatAttr(floatType, 0.0)); } - if (auto vectorType = type.dyn_cast()) { + if (auto vectorType = llvm::dyn_cast(type)) { Type elemType = vectorType.getElementType(); - if (elemType.isa()) { + if (llvm::isa(elemType)) { return builder.create( loc, type, DenseElementsAttr::get(vectorType, IntegerAttr::get(elemType, 0).getValue())); } - if (elemType.isa()) { + if (llvm::isa(elemType)) { return builder.create( loc, type, DenseFPElementsAttr::get(vectorType, @@ -2076,7 +2076,7 @@ spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc, OpBuilder &builder) { - if (auto intType = type.dyn_cast()) { + if (auto intType = llvm::dyn_cast(type)) { unsigned width = intType.getWidth(); if (width == 1) return builder.create(loc, type, @@ -2084,19 +2084,19 @@ return builder.create( loc, type, builder.getIntegerAttr(type, APInt(width, 1))); } - if (auto floatType = type.dyn_cast()) { + if (auto floatType = llvm::dyn_cast(type)) { return builder.create( loc, type, builder.getFloatAttr(floatType, 1.0)); } - if (auto vectorType = type.dyn_cast()) { + if (auto vectorType = llvm::dyn_cast(type)) { Type elemType = vectorType.getElementType(); - if (elemType.isa()) { + if (llvm::isa(elemType)) { return builder.create( loc, type, DenseElementsAttr::get(vectorType, IntegerAttr::get(elemType, 1).getValue())); } - if (elemType.isa()) { + if (llvm::isa(elemType)) { return builder.create( loc, type, DenseFPElementsAttr::get(vectorType, @@ -2115,9 +2115,9 @@ llvm::raw_svector_ostream specialName(specialNameBuffer); specialName << "cst"; - IntegerType intTy = type.dyn_cast(); + IntegerType intTy = llvm::dyn_cast(type); - if (IntegerAttr intCst = getValue().dyn_cast()) { + if (IntegerAttr intCst = llvm::dyn_cast(getValue())) { if (intTy && intTy.getWidth() == 1) { return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); } @@ -2131,17 +2131,18 @@ } } - if (intTy || type.isa()) { + if (intTy || llvm::isa(type)) { specialName << '_' << type; } - if (auto vecType = type.dyn_cast()) { + if (auto vecType = llvm::dyn_cast(type)) { specialName << "_vec_"; specialName << vecType.getDimSize(0); Type elementType = vecType.getElementType(); - if (elementType.isa() || elementType.isa()) { + if (llvm::isa(elementType) || + llvm::isa(elementType)) { specialName << "x" << elementType; } } @@ -2210,9 +2211,10 @@ auto resultType = getResult().getType(); // ODS checks that vector result type and vector operand type have the same // shape. - if (auto vectorType = operandType.dyn_cast()) { + if (auto vectorType = llvm::dyn_cast(operandType)) { unsigned operandNumElements = vectorType.getNumElements(); - unsigned resultNumElements = resultType.cast().getNumElements(); + unsigned resultNumElements = + llvm::cast(resultType).getNumElements(); if (operandNumElements != resultNumElements) { return emitOpError( "operand and result must have same number of elements"); @@ -2230,9 +2232,10 @@ auto resultType = getResult().getType(); // ODS checks that vector result type and vector operand type have the same // shape. - if (auto vectorType = operandType.dyn_cast()) { + if (auto vectorType = llvm::dyn_cast(operandType)) { unsigned operandNumElements = vectorType.getNumElements(); - unsigned resultNumElements = resultType.cast().getNumElements(); + unsigned resultNumElements = + llvm::cast(resultType).getNumElements(); if (operandNumElements != resultNumElements) { return emitOpError( "operand and result must have same number of elements"); @@ -2331,7 +2334,7 @@ if (parser.parseAttribute(value, i32Type, "value", attr)) { return failure(); } - values.push_back(value.cast().getInt()); + values.push_back(llvm::cast(value).getInt()); } result.addAttribute(kValuesAttrName, parser.getBuilder().getI32ArrayAttr(values)); @@ -2347,7 +2350,7 @@ return; printer << ", "; llvm::interleaveComma(values, printer, [&](Attribute a) { - printer << a.cast().getInt(); + printer << llvm::cast(a).getInt(); }); } @@ -2677,7 +2680,7 @@ if (parser.parseColonType(type)) { return failure(); } - if (!type.isa()) { + if (!llvm::isa(type)) { return parser.emitError(loc, "expected spirv.ptr type"); } result.addAttribute(kTypeAttrName, TypeAttr::get(type)); @@ -2708,7 +2711,7 @@ } LogicalResult spirv::GlobalVariableOp::verify() { - if (!getType().isa()) + if (!llvm::isa(getType())) return emitOpError("result must be of a !spv.ptr type"); // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the @@ -2748,7 +2751,7 @@ if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); - if (auto localIdTy = getLocalid().getType().dyn_cast()) + if (auto localIdTy = llvm::dyn_cast(getLocalid().getType())) if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3) return emitOpError("localid is a vector and can be with only " " 2 or 3 components, actual number is ") @@ -2839,7 +2842,7 @@ } auto ptrType = spirv::PointerType::get(elementType, storageClass); - if (auto valVecTy = elementType.dyn_cast()) + if (auto valVecTy = llvm::dyn_cast(elementType)) ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass); if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) { @@ -2879,7 +2882,7 @@ } auto ptrType = spirv::PointerType::get(elementType, storageClass); - if (auto valVecTy = elementType.dyn_cast()) + if (auto valVecTy = llvm::dyn_cast(elementType)) ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass); if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc, @@ -3148,7 +3151,7 @@ void spirv::LoadOp::build(OpBuilder &builder, OperationState &state, Value basePtr, MemoryAccessAttr memoryAccess, IntegerAttr alignment) { - auto ptrType = basePtr.getType().cast(); + auto ptrType = llvm::cast(basePtr.getType()); build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess, alignment); } @@ -3177,7 +3180,7 @@ void spirv::LoadOp::print(OpAsmPrinter &printer) { SmallVector elidedAttrs; StringRef sc = stringifyStorageClass( - getPtr().getType().cast().getStorageClass()); + llvm::cast(getPtr().getType()).getStorageClass()); printer << " \"" << sc << "\" " << getPtr(); printMemoryAccessAttribute(*this, printer, elidedAttrs); @@ -3494,7 +3497,7 @@ } if (auto interface = entryPointOp.getInterface()) { for (Attribute varRef : interface) { - auto varSymRef = varRef.dyn_cast(); + auto varSymRef = llvm::dyn_cast(varRef); if (!varSymRef) { return entryPointOp.emitError( "expected symbol reference for interface " @@ -3587,8 +3590,8 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::SelectOp::verify() { - if (auto conditionTy = getCondition().getType().dyn_cast()) { - auto resultVectorTy = getResult().getType().dyn_cast(); + if (auto conditionTy = llvm::dyn_cast(getCondition().getType())) { + auto resultVectorTy = llvm::dyn_cast(getResult().getType()); if (!resultVectorTy) { return emitOpError("result expected to be of vector type when " "condition is of vector type"); @@ -3760,9 +3763,9 @@ return emitOpError("SpecId cannot be negative"); auto value = getDefaultValue(); - if (value.isa()) { + if (llvm::isa(value)) { // Make sure bitwidth is allowed. - if (!value.getType().isa()) + if (!llvm::isa(value.getType())) return emitOpError("default value bitwidth disallowed"); return success(); } @@ -3798,7 +3801,7 @@ void spirv::StoreOp::print(OpAsmPrinter &printer) { SmallVector elidedAttrs; StringRef sc = stringifyStorageClass( - getPtr().getType().cast().getStorageClass()); + llvm::cast(getPtr().getType()).getStorageClass()); printer << " \"" << sc << "\" " << getPtr() << ", " << getValue(); printMemoryAccessAttribute(*this, printer, elidedAttrs); @@ -3861,7 +3864,7 @@ if (parser.parseType(type)) return failure(); - auto ptrType = type.dyn_cast(); + auto ptrType = llvm::dyn_cast(type); if (!ptrType) return parser.emitError(loc, "expected spirv.ptr type"); result.addTypes(ptrType); @@ -3901,7 +3904,7 @@ "spirv.GlobalVariable for module-level variables."); } - auto pointerType = getPointer().getType().cast(); + auto pointerType = llvm::cast(getPointer().getType()); if (getStorageClass() != pointerType.getStorageClass()) return emitOpError( "storage class must match result pointer's storage class"); @@ -3940,7 +3943,7 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::VectorShuffleOp::verify() { - VectorType resultType = getType().cast(); + VectorType resultType = llvm::cast(getType()); size_t numResultElements = resultType.getNumElements(); if (numResultElements != getComponents().size()) @@ -3950,8 +3953,8 @@ << getComponents().size() << ")"; size_t totalSrcElements = - getVector1().getType().cast().getNumElements() + - getVector2().getType().cast().getNumElements(); + llvm::cast(getVector1().getType()).getNumElements() + + llvm::cast(getVector2().getType()).getNumElements(); for (const auto &selector : getComponents().getAsValueRange()) { uint32_t index = selector.getZExtValue(); @@ -4001,13 +4004,14 @@ static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer, Type coopMatrix) { - Type pointeeType = pointer.cast().getPointeeType(); - if (!pointeeType.isa() && !pointeeType.isa()) + Type pointeeType = llvm::cast(pointer).getPointeeType(); + if (!llvm::isa(pointeeType) && + !llvm::isa(pointeeType)) return op->emitError( "Pointer must point to a scalar or vector type but provided ") << pointeeType; spirv::StorageClass storage = - pointer.cast().getStorageClass(); + llvm::cast(pointer).getStorageClass(); if (storage != spirv::StorageClass::Workgroup && storage != spirv::StorageClass::StorageBuffer && storage != spirv::StorageClass::PhysicalStorageBuffer) @@ -4071,10 +4075,11 @@ verifyCoopMatrixMulAdd(spirv::NVCooperativeMatrixMulAddOp op) { if (op.getC().getType() != op.getResult().getType()) return op.emitOpError("result and third operand must have the same type"); - auto typeA = op.getA().getType().cast(); - auto typeB = op.getB().getType().cast(); - auto typeC = op.getC().getType().cast(); - auto typeR = op.getResult().getType().cast(); + auto typeA = llvm::cast(op.getA().getType()); + auto typeB = llvm::cast(op.getB().getType()); + auto typeC = llvm::cast(op.getC().getType()); + auto typeR = + llvm::cast(op.getResult().getType()); if (typeA.getRows() != typeR.getRows() || typeA.getColumns() != typeB.getRows() || typeB.getColumns() != typeR.getColumns()) @@ -4086,8 +4091,8 @@ auto elementTypeA = typeA.getElementType(); auto elementTypeB = typeB.getElementType(); if (isa(elementTypeA) && isa(elementTypeB)) { - if (elementTypeA.cast().getWidth() != - elementTypeB.cast().getWidth()) + if (llvm::cast(elementTypeA).getWidth() != + llvm::cast(elementTypeB).getWidth()) return op.emitOpError( "matrix A and B integer element types must be the same bit width"); } else if (elementTypeA != elementTypeB) { @@ -4105,13 +4110,14 @@ static LogicalResult verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) { - Type pointeeType = pointer.cast().getPointeeType(); - if (!pointeeType.isa() && !pointeeType.isa()) + Type pointeeType = llvm::cast(pointer).getPointeeType(); + if (!llvm::isa(pointeeType) && + !llvm::isa(pointeeType)) return op->emitError( "Pointer must point to a scalar or vector type but provided ") << pointeeType; spirv::StorageClass storage = - pointer.cast().getStorageClass(); + llvm::cast(pointer).getStorageClass(); if (storage != spirv::StorageClass::Workgroup && storage != spirv::StorageClass::CrossWorkgroup && storage != spirv::StorageClass::UniformConstant && @@ -4147,10 +4153,11 @@ static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op) { if (op.getC().getType() != op.getResult().getType()) return op.emitOpError("result and third operand must have the same type"); - auto typeA = op.getA().getType().cast(); - auto typeB = op.getB().getType().cast(); - auto typeC = op.getC().getType().cast(); - auto typeR = op.getResult().getType().cast(); + auto typeA = llvm::cast(op.getA().getType()); + auto typeB = llvm::cast(op.getB().getType()); + auto typeC = llvm::cast(op.getC().getType()); + auto typeR = + llvm::cast(op.getResult().getType()); if (typeA.getRows() != typeR.getRows() || typeA.getColumns() != typeB.getRows() || typeB.getColumns() != typeR.getColumns()) @@ -4174,8 +4181,8 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::MatrixTimesScalarOp::verify() { - if (auto inputCoopmat = - getMatrix().getType().dyn_cast()) { + if (auto inputCoopmat = llvm::dyn_cast( + getMatrix().getType())) { if (inputCoopmat.getElementType() != getScalar().getType()) return emitError("input matrix components' type and scaling value must " "have the same type"); @@ -4183,7 +4190,7 @@ } // Check that the scalar type is the same as the matrix element type. - auto inputMatrix = getMatrix().getType().cast(); + auto inputMatrix = llvm::cast(getMatrix().getType()); if (getScalar().getType() != inputMatrix.getElementType()) return emitError("input matrix components' type and scaling value must " "have the same type"); @@ -4199,11 +4206,11 @@ printer << ' '; StringRef targetStorageClass = stringifyStorageClass( - getTarget().getType().cast().getStorageClass()); + llvm::cast(getTarget().getType()).getStorageClass()); printer << " \"" << targetStorageClass << "\" " << getTarget() << ", "; StringRef sourceStorageClass = stringifyStorageClass( - getSource().getType().cast().getStorageClass()); + llvm::cast(getSource().getType()).getStorageClass()); printer << " \"" << sourceStorageClass << "\" " << getSource(); SmallVector elidedAttrs; @@ -4215,7 +4222,7 @@ printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); Type pointeeType = - getTarget().getType().cast().getPointeeType(); + llvm::cast(getTarget().getType()).getPointeeType(); printer << " : " << pointeeType; } @@ -4263,10 +4270,10 @@ LogicalResult spirv::CopyMemoryOp::verify() { Type targetType = - getTarget().getType().cast().getPointeeType(); + llvm::cast(getTarget().getType()).getPointeeType(); Type sourceType = - getSource().getType().cast().getPointeeType(); + llvm::cast(getSource().getType()).getPointeeType(); if (targetType != sourceType) return emitOpError("both operands must be pointers to the same type"); @@ -4290,8 +4297,8 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::TransposeOp::verify() { - auto inputMatrix = getMatrix().getType().cast(); - auto resultMatrix = getResult().getType().cast(); + auto inputMatrix = llvm::cast(getMatrix().getType()); + auto resultMatrix = llvm::cast(getResult().getType()); // Verify that the input and output matrices have correct shapes. if (inputMatrix.getNumRows() != resultMatrix.getNumColumns()) @@ -4315,9 +4322,9 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::MatrixTimesMatrixOp::verify() { - auto leftMatrix = getLeftmatrix().getType().cast(); - auto rightMatrix = getRightmatrix().getType().cast(); - auto resultMatrix = getResult().getType().cast(); + auto leftMatrix = llvm::cast(getLeftmatrix().getType()); + auto rightMatrix = llvm::cast(getRightmatrix().getType()); + auto resultMatrix = llvm::cast(getResult().getType()); // left matrix columns' count and right matrix rows' count must be equal if (leftMatrix.getNumColumns() != rightMatrix.getNumRows()) @@ -4403,16 +4410,16 @@ } LogicalResult spirv::SpecConstantCompositeOp::verify() { - auto cType = getType().dyn_cast(); + auto cType = llvm::dyn_cast(getType()); auto constituents = this->getConstituents().getValue(); if (!cType) return emitError("result type must be a composite type, but provided ") << getType(); - if (cType.isa()) + if (llvm::isa(cType)) return emitError("unsupported composite type ") << cType; - if (cType.isa()) + if (llvm::isa(cType)) return emitError("unsupported composite type ") << cType; if (constituents.size() != cType.getNumElements()) return emitError("has incorrect number of operands: expected ") @@ -4420,7 +4427,7 @@ << constituents.size(); for (auto index : llvm::seq(0, constituents.size())) { - auto constituent = constituents[index].cast(); + auto constituent = llvm::cast(constituents[index]); auto constituentSpecConstOp = dyn_cast(SymbolTable::lookupNearestSymbolFrom( @@ -4498,19 +4505,19 @@ LogicalResult spirv::GLFrexpStructOp::verify() { spirv::StructType structTy = - getResult().getType().dyn_cast(); + llvm::dyn_cast(getResult().getType()); if (structTy.getNumElements() != 2) return emitError("result type must be a struct type with two memebers"); Type significandTy = structTy.getElementType(0); Type exponentTy = structTy.getElementType(1); - VectorType exponentVecTy = exponentTy.dyn_cast(); - IntegerType exponentIntTy = exponentTy.dyn_cast(); + VectorType exponentVecTy = llvm::dyn_cast(exponentTy); + IntegerType exponentIntTy = llvm::dyn_cast(exponentTy); Type operandTy = getOperand().getType(); - VectorType operandVecTy = operandTy.dyn_cast(); - FloatType operandFTy = operandTy.dyn_cast(); + VectorType operandVecTy = llvm::dyn_cast(operandTy); + FloatType operandFTy = llvm::dyn_cast(operandTy); if (significandTy != operandTy) return emitError("member zero of the resulting struct type must be the " @@ -4518,7 +4525,7 @@ if (exponentVecTy) { IntegerType componentIntTy = - exponentVecTy.getElementType().dyn_cast(); + llvm::dyn_cast(exponentVecTy.getElementType()); if (!componentIntTy || componentIntTy.getWidth() != 32) return emitError("member one of the resulting struct type must" "be a scalar or vector of 32 bit integer type"); @@ -4547,11 +4554,12 @@ Type significandType = getX().getType(); Type exponentType = getExp().getType(); - if (significandType.isa() != exponentType.isa()) + if (llvm::isa(significandType) != + llvm::isa(exponentType)) return emitOpError("operands must both be scalars or vectors"); auto getNumElements = [](Type type) -> unsigned { - if (auto vectorType = type.dyn_cast()) + if (auto vectorType = llvm::dyn_cast(type)) return vectorType.getNumElements(); return 1; }; @@ -4567,17 +4575,19 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::ImageDrefGatherOp::verify() { - VectorType resultType = getResult().getType().cast(); + VectorType resultType = llvm::cast(getResult().getType()); auto sampledImageType = - getSampledimage().getType().cast(); - auto imageType = sampledImageType.getImageType().cast(); + llvm::cast(getSampledimage().getType()); + auto imageType = + llvm::cast(sampledImageType.getImageType()); if (resultType.getNumElements() != 4) return emitOpError("result type must be a vector of four components"); Type elementType = resultType.getElementType(); Type sampledElementType = imageType.getElementType(); - if (!sampledElementType.isa() && elementType != sampledElementType) + if (!llvm::isa(sampledElementType) && + elementType != sampledElementType) return emitOpError( "the component type of result must be the same as sampled type of the " "underlying image type"); @@ -4629,7 +4639,8 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::ImageQuerySizeOp::verify() { - spirv::ImageType imageType = getImage().getType().cast(); + spirv::ImageType imageType = + llvm::cast(getImage().getType()); Type resultType = getResult().getType(); spirv::Dim dim = imageType.getDim(); @@ -4677,7 +4688,7 @@ componentNumber += 1; unsigned resultComponentNumber = 1; - if (auto resultVectorType = resultType.dyn_cast()) + if (auto resultVectorType = llvm::dyn_cast(resultType)) resultComponentNumber = resultVectorType.getNumElements(); if (componentNumber != resultComponentNumber) @@ -4798,7 +4809,7 @@ LogicalResult spirv::VectorTimesScalarOp::verify() { if (getVector().getType() != getType()) return emitOpError("vector operand and result type mismatch"); - auto scalarType = getType().cast().getElementType(); + auto scalarType = llvm::cast(getType()).getElementType(); if (getScalar().getType() != scalarType) return emitOpError("scalar operand and result element type match"); return success(); @@ -4851,11 +4862,11 @@ return op->emitOpError("requires the same type for both vector operands"); unsigned expectedNumAttrs = 0; - if (auto intTy = factorTy.dyn_cast()) { + if (auto intTy = llvm::dyn_cast(factorTy)) { ++expectedNumAttrs; auto packedVectorFormat = - op->getAttr(kPackedVectorFormatAttrName) - .dyn_cast_or_null(); + llvm::dyn_cast_or_null( + op->getAttr(kPackedVectorFormatAttrName)); if (!packedVectorFormat) return op->emitOpError("requires Packed Vector Format attribute for " "integer vector operands"); @@ -4927,9 +4938,9 @@ SmallVector, 1> capabilities = {dotProductCap}; Type factorTy = op->getOperand(0).getType(); - if (auto intTy = factorTy.dyn_cast()) { - auto formatAttr = op->getAttr(kPackedVectorFormatAttrName) - .cast(); + if (auto intTy = llvm::dyn_cast(factorTy)) { + auto formatAttr = llvm::cast( + op->getAttr(kPackedVectorFormatAttrName)); if (formatAttr.getValue() == spirv::PackedVectorFormat::PackedVectorFormat4x8Bit) capabilities.push_back(dotProductInput4x8BitPackedCap); @@ -4937,7 +4948,7 @@ return capabilities; } - auto vecTy = factorTy.cast(); + auto vecTy = llvm::cast(factorTy); if (vecTy.getElementTypeBitWidth() == 8) { capabilities.push_back(dotProductInput4x8BitCap); return capabilities; diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -68,17 +68,18 @@ void ArrayType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional storage) { - getElementType().cast().getExtensions(extensions, storage); + llvm::cast(getElementType()).getExtensions(extensions, storage); } void ArrayType::getCapabilities( SPIRVType::CapabilityArrayRefVector &capabilities, std::optional storage) { - getElementType().cast().getCapabilities(capabilities, storage); + llvm::cast(getElementType()) + .getCapabilities(capabilities, storage); } std::optional ArrayType::getSizeInBytes() { - auto elementType = getElementType().cast(); + auto elementType = llvm::cast(getElementType()); std::optional size = elementType.getSizeInBytes(); if (!size) return std::nullopt; @@ -90,11 +91,11 @@ //===----------------------------------------------------------------------===// bool CompositeType::classof(Type type) { - if (auto vectorType = type.dyn_cast()) + if (auto vectorType = llvm::dyn_cast(type)) return isValid(vectorType); - return type.isa(); + return llvm::isa(type); } bool CompositeType::isValid(VectorType type) { @@ -108,7 +109,7 @@ default: return false; } - return type.getRank() == 1 && type.getElementType().isa(); + return type.getRank() == 1 && llvm::isa(type.getElementType()); } Type CompositeType::getElementType(unsigned index) const { @@ -160,8 +161,8 @@ MatrixType, RuntimeArrayType, StructType>( [&](auto type) { type.getExtensions(extensions, storage); }) .Case([&](VectorType type) { - return type.getElementType().cast().getExtensions( - extensions, storage); + return llvm::cast(type.getElementType()) + .getExtensions(extensions, storage); }) .Default([](Type) { llvm_unreachable("invalid composite type"); }); } @@ -180,8 +181,8 @@ ArrayRef ref(caps, std::size(caps)); capabilities.push_back(ref); } - return type.getElementType().cast().getCapabilities( - capabilities, storage); + return llvm::cast(type.getElementType()) + .getCapabilities(capabilities, storage); }) .Default([](Type) { llvm_unreachable("invalid composite type"); }); } @@ -193,7 +194,7 @@ return structType.getSizeInBytes(); if (auto vectorType = dyn_cast()) { std::optional elementSize = - vectorType.getElementType().cast().getSizeInBytes(); + llvm::cast(vectorType.getElementType()).getSizeInBytes(); if (!elementSize) return std::nullopt; return *elementSize * vectorType.getNumElements(); @@ -249,7 +250,7 @@ void CooperativeMatrixNVType::getExtensions( SPIRVType::ExtensionArrayRefVector &extensions, std::optional storage) { - getElementType().cast().getExtensions(extensions, storage); + llvm::cast(getElementType()).getExtensions(extensions, storage); static const Extension exts[] = {Extension::SPV_NV_cooperative_matrix}; ArrayRef ref(exts, std::size(exts)); extensions.push_back(ref); @@ -258,7 +259,8 @@ void CooperativeMatrixNVType::getCapabilities( SPIRVType::CapabilityArrayRefVector &capabilities, std::optional storage) { - getElementType().cast().getCapabilities(capabilities, storage); + llvm::cast(getElementType()) + .getCapabilities(capabilities, storage); static const Capability caps[] = {Capability::CooperativeMatrixNV}; ArrayRef ref(caps, std::size(caps)); capabilities.push_back(ref); @@ -317,7 +319,7 @@ void JointMatrixINTELType::getExtensions( SPIRVType::ExtensionArrayRefVector &extensions, std::optional storage) { - getElementType().cast().getExtensions(extensions, storage); + llvm::cast(getElementType()).getExtensions(extensions, storage); static const Extension exts[] = {Extension::SPV_INTEL_joint_matrix}; ArrayRef ref(exts, std::size(exts)); extensions.push_back(ref); @@ -326,7 +328,8 @@ void JointMatrixINTELType::getCapabilities( SPIRVType::CapabilityArrayRefVector &capabilities, std::optional storage) { - getElementType().cast().getCapabilities(capabilities, storage); + llvm::cast(getElementType()) + .getCapabilities(capabilities, storage); static const Capability caps[] = {Capability::JointMatrixINTEL}; ArrayRef ref(caps, std::size(caps)); capabilities.push_back(ref); @@ -489,8 +492,8 @@ std::optional storage) { // Use this pointer type's storage class because this pointer indicates we are // using the pointee type in that specific storage class. - getPointeeType().cast().getExtensions(extensions, - getStorageClass()); + llvm::cast(getPointeeType()) + .getExtensions(extensions, getStorageClass()); if (auto scExts = spirv::getExtensions(getStorageClass())) extensions.push_back(*scExts); @@ -501,8 +504,8 @@ std::optional storage) { // Use this pointer type's storage class because this pointer indicates we are // using the pointee type in that specific storage class. - getPointeeType().cast().getCapabilities(capabilities, - getStorageClass()); + llvm::cast(getPointeeType()) + .getCapabilities(capabilities, getStorageClass()); if (auto scCaps = spirv::getCapabilities(getStorageClass())) capabilities.push_back(*scCaps); @@ -547,7 +550,7 @@ void RuntimeArrayType::getExtensions( SPIRVType::ExtensionArrayRefVector &extensions, std::optional storage) { - getElementType().cast().getExtensions(extensions, storage); + llvm::cast(getElementType()).getExtensions(extensions, storage); } void RuntimeArrayType::getCapabilities( @@ -558,7 +561,8 @@ ArrayRef ref(caps, std::size(caps)); capabilities.push_back(ref); } - getElementType().cast().getCapabilities(capabilities, storage); + llvm::cast(getElementType()) + .getCapabilities(capabilities, storage); } //===----------------------------------------------------------------------===// @@ -566,10 +570,10 @@ //===----------------------------------------------------------------------===// bool ScalarType::classof(Type type) { - if (auto floatType = type.dyn_cast()) { + if (auto floatType = llvm::dyn_cast(type)) { return isValid(floatType); } - if (auto intType = type.dyn_cast()) { + if (auto intType = llvm::dyn_cast(type)) { return isValid(intType); } return false; @@ -723,9 +727,9 @@ // Allow SPIR-V dialect types if (llvm::isa(type.getDialect())) return true; - if (type.isa()) + if (llvm::isa(type)) return true; - if (auto vectorType = type.dyn_cast()) + if (auto vectorType = llvm::dyn_cast(type)) return CompositeType::isValid(vectorType); return false; } @@ -815,7 +819,7 @@ LogicalResult SampledImageType::verify(function_ref emitError, Type imageType) { - if (!imageType.isa()) + if (!llvm::isa(imageType)) return emitError() << "expected image type"; return success(); @@ -824,13 +828,13 @@ void SampledImageType::getExtensions( SPIRVType::ExtensionArrayRefVector &extensions, std::optional storage) { - getImageType().cast().getExtensions(extensions, storage); + llvm::cast(getImageType()).getExtensions(extensions, storage); } void SampledImageType::getCapabilities( SPIRVType::CapabilityArrayRefVector &capabilities, std::optional storage) { - getImageType().cast().getCapabilities(capabilities, storage); + llvm::cast(getImageType()).getCapabilities(capabilities, storage); } //===----------------------------------------------------------------------===// @@ -1125,14 +1129,14 @@ void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional storage) { for (Type elementType : getElementTypes()) - elementType.cast().getExtensions(extensions, storage); + llvm::cast(elementType).getExtensions(extensions, storage); } void StructType::getCapabilities( SPIRVType::CapabilityArrayRefVector &capabilities, std::optional storage) { for (Type elementType : getElementTypes()) - elementType.cast().getCapabilities(capabilities, storage); + llvm::cast(elementType).getCapabilities(capabilities, storage); } llvm::hash_code spirv::hash_value( @@ -1186,7 +1190,7 @@ return emitError() << "matrix columns must be vectors of floats"; /// The underlying vectors (columns) must be of size 2, 3, or 4 - ArrayRef columnShape = columnType.cast().getShape(); + ArrayRef columnShape = llvm::cast(columnType).getShape(); if (columnShape.size() != 1) return emitError() << "matrix columns must be 1D vectors"; @@ -1198,8 +1202,8 @@ /// Returns true if the matrix elements are vectors of float elements bool MatrixType::isValidColumnType(Type columnType) { - if (auto vectorType = columnType.dyn_cast()) { - if (vectorType.getElementType().isa()) + if (auto vectorType = llvm::dyn_cast(columnType)) { + if (llvm::isa(vectorType.getElementType())) return true; } return false; @@ -1208,13 +1212,13 @@ Type MatrixType::getColumnType() const { return getImpl()->columnType; } Type MatrixType::getElementType() const { - return getImpl()->columnType.cast().getElementType(); + return llvm::cast(getImpl()->columnType).getElementType(); } unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; } unsigned MatrixType::getNumRows() const { - return getImpl()->columnType.cast().getShape()[0]; + return llvm::cast(getImpl()->columnType).getShape()[0]; } unsigned MatrixType::getNumElements() const { @@ -1223,7 +1227,7 @@ void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional storage) { - getColumnType().cast().getExtensions(extensions, storage); + llvm::cast(getColumnType()).getExtensions(extensions, storage); } void MatrixType::getCapabilities( @@ -1235,7 +1239,7 @@ capabilities.push_back(ref); } // Add any capabilities associated with the underlying vectors (i.e., columns) - getColumnType().cast().getCapabilities(capabilities, storage); + llvm::cast(getColumnType()).getCapabilities(capabilities, storage); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -730,11 +730,10 @@ // block statically used per shader entry point." So we should always reuse // the existing one. if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) { - auto numElements = - cast( - ptrType.getPointeeType().cast().getElementType( - 0)) - .getNumElements(); + auto numElements = cast(llvm::cast( + ptrType.getPointeeType()) + .getElementType(0)) + .getNumElements(); if (numElements == elementCount) return varOp; } diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -41,14 +41,14 @@ } bool shape::isExtentTensorType(Type type) { - auto ranked = type.dyn_cast(); + auto ranked = llvm::dyn_cast(type); return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex(); } LogicalResult shape::getShapeVec(Value input, SmallVectorImpl &shapeValues) { if (auto inputOp = input.getDefiningOp()) { - auto type = inputOp.getArg().getType().cast(); + auto type = llvm::cast(inputOp.getArg().getType()); if (!type.hasRank()) return failure(); llvm::append_range(shapeValues, type.getShape()); @@ -64,7 +64,7 @@ static bool isErrorPropagationPossible(TypeRange operandTypes) { return llvm::any_of(operandTypes, [](Type ty) { - return ty.isa(); + return llvm::isa(ty); }); } @@ -72,7 +72,7 @@ assert(op != nullptr && op->getNumResults() == 1); Type resultTy = op->getResultTypes().front(); if (isErrorPropagationPossible(op->getOperandTypes())) { - if (!resultTy.isa()) + if (!llvm::isa(resultTy)) return op->emitOpError() << "if at least one of the operands can hold error values then " "the result must be of type `size` to propagate them"; @@ -84,7 +84,7 @@ assert(op != nullptr && op->getNumResults() == 1); Type resultTy = op->getResultTypes().front(); if (isErrorPropagationPossible(op->getOperandTypes())) { - if (!resultTy.isa()) + if (!llvm::isa(resultTy)) return op->emitOpError() << "if at least one of the operands can hold error values then " "the result must be of type `shape` to propagate them"; @@ -94,7 +94,7 @@ template static bool eachHasOnlyOneOfTypes(TypeRange typeRange) { - return typeRange.size() == 1 && typeRange.front().isa(); + return typeRange.size() == 1 && llvm::isa(typeRange.front()); } template @@ -147,13 +147,15 @@ Operation *ShapeDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { - if (type.isa() || isExtentTensorType(type)) - return builder.create(loc, type, - value.cast()); - if (type.isa()) - return builder.create(loc, type, value.cast()); - if (type.isa()) - return builder.create(loc, type, value.cast()); + if (llvm::isa(type) || isExtentTensorType(type)) + return builder.create( + loc, type, llvm::cast(value)); + if (llvm::isa(type)) + return builder.create(loc, type, + llvm::cast(value)); + if (llvm::isa(type)) + return builder.create(loc, type, + llvm::cast(value)); return arith::ConstantOp::materialize(builder, value, type, loc); } @@ -165,7 +167,7 @@ return op->emitError( "shape.lib attribute may only be on op implementing SymbolTable"); - if (auto symbolRef = attribute.getValue().dyn_cast()) { + if (auto symbolRef = llvm::dyn_cast(attribute.getValue())) { auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef); if (!symbol) return op->emitError("shape function library ") @@ -176,17 +178,17 @@ << symbolRef << " required to be shape function library"; } - if (auto arr = attribute.getValue().dyn_cast()) { + if (auto arr = llvm::dyn_cast(attribute.getValue())) { // Verify all entries are function libraries and mappings in libraries // refer to unique ops. DenseSet key; for (auto it : arr) { - if (!it.isa()) + if (!llvm::isa(it)) return op->emitError( "only SymbolRefAttr allowed in shape.lib attribute array"); auto shapeFnLib = dyn_cast( - SymbolTable::lookupSymbolIn(op, it.cast())); + SymbolTable::lookupSymbolIn(op, llvm::cast(it))); if (!shapeFnLib) return op->emitError() << it << " does not refer to FunctionLibraryOp"; @@ -395,8 +397,8 @@ MLIRContext *context, std::optional location, ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { - if (operands[0].getType().isa() || - operands[1].getType().isa()) + if (llvm::isa(operands[0].getType()) || + llvm::isa(operands[1].getType())) inferredReturnTypes.assign({SizeType::get(context)}); else inferredReturnTypes.assign({IndexType::get(context)}); @@ -617,7 +619,7 @@ getOperation()->eraseOperand(idx); // Always false if any input is statically known false - if (!a.cast().getValue()) + if (!llvm::cast(a).getValue()) return a; } // If this is reached, all inputs were statically known passing. @@ -651,9 +653,11 @@ if (!adaptor.getShapes()[0] || !adaptor.getShapes()[1]) return nullptr; auto lhsShape = llvm::to_vector<6>( - adaptor.getShapes()[0].cast().getValues()); + llvm::cast(adaptor.getShapes()[0]) + .getValues()); auto rhsShape = llvm::to_vector<6>( - adaptor.getShapes()[1].cast().getValues()); + llvm::cast(adaptor.getShapes()[1]) + .getValues()); SmallVector resultShape; // If the shapes are not compatible, we can't fold it. @@ -677,7 +681,8 @@ LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { auto isPotentiallyNonEmptyShape = [](Value shape) { - if (auto extentTensorTy = shape.getType().dyn_cast()) { + if (auto extentTensorTy = + llvm::dyn_cast(shape.getType())) { if (extentTensorTy.getDimSize(0) == 0) return false; } @@ -714,11 +719,11 @@ // Insert cast if needed. if (replacement.getType() != op.getType()) { auto loc = op.getLoc(); - if (op.getType().isa()) { + if (llvm::isa(op.getType())) { replacement = rewriter.create(loc, replacement); } else { - assert(!op.getType().isa() && - !replacement.getType().isa() && + assert(!llvm::isa(op.getType()) && + !llvm::isa(replacement.getType()) && "expect extent tensor cast"); replacement = rewriter.create(loc, op.getType(), replacement); @@ -781,7 +786,7 @@ if (auto castOp = operand.getDefiningOp()) { // Only eliminate the cast if it holds no shape information. bool isInformationLoosingCast = - castOp.getType().cast().isDynamicDim(0); + llvm::cast(castOp.getType()).isDynamicDim(0); if (isInformationLoosingCast) { anyChange = true; return castOp.getSource(); @@ -807,14 +812,15 @@ LogicalResult matchAndRewrite(BroadcastOp op, PatternRewriter &rewriter) const override { // Only concretize dynamic extent tensor result types. - auto resultTy = op.getType().dyn_cast(); + auto resultTy = llvm::dyn_cast(op.getType()); if (!resultTy || !resultTy.isDynamicDim(0)) return failure(); // Infer resulting shape rank if possible. int64_t maxRank = 0; for (Value shape : op.getShapes()) { - if (auto extentTensorTy = shape.getType().dyn_cast()) { + if (auto extentTensorTy = + llvm::dyn_cast(shape.getType())) { // Cannot infer resulting shape rank if any operand is dynamically // ranked. if (extentTensorTy.isDynamicDim(0)) @@ -883,12 +889,12 @@ NamedAttrList dummy; if (parser.parseAttribute(extentsRaw, "dummy", dummy)) return failure(); - auto extentsArray = extentsRaw.dyn_cast(); + auto extentsArray = llvm::dyn_cast(extentsRaw); if (!extentsArray) return failure(); SmallVector ints; for (Attribute extent : extentsArray) { - IntegerAttr attr = extent.dyn_cast(); + IntegerAttr attr = llvm::dyn_cast(extent); if (!attr) return failure(); ints.push_back(attr.getInt()); @@ -930,7 +936,7 @@ Type lhs = l.front(); Type rhs = r.front(); - if (lhs.isa() || rhs.isa()) + if (llvm::isa(lhs) || llvm::isa(rhs)) // Shape type is compatible with all other valid return types. return true; return lhs == rhs; @@ -956,7 +962,7 @@ static bool hasAtMostSingleNonScalar(ArrayRef attributes) { bool nonScalarSeen = false; for (Attribute a : attributes) { - if (!a || a.cast().getNumElements() != 0) { + if (!a || llvm::cast(a).getNumElements() != 0) { if (nonScalarSeen) return false; nonScalarSeen = true; @@ -1070,13 +1076,13 @@ if (auto constSizeOp = getIndex().getDefiningOp()) return constSizeOp.getValue().getLimitedValue(); if (auto constantOp = getIndex().getDefiningOp()) - return constantOp.getValue().cast().getInt(); + return llvm::cast(constantOp.getValue()).getInt(); return std::nullopt; } OpFoldResult DimOp::fold(FoldAdaptor adaptor) { Type valType = getValue().getType(); - auto valShapedType = valType.dyn_cast(); + auto valShapedType = llvm::dyn_cast(valType); if (!valShapedType || !valShapedType.hasRank()) return nullptr; std::optional index = getConstantIndex(); @@ -1104,7 +1110,7 @@ } LogicalResult mlir::shape::DimOp::verify() { - auto st = getValue().getType().cast(); + auto st = llvm::cast(getValue().getType()); if (!st.hasRank()) return success(); if (auto index = getConstantIndex()) { @@ -1142,8 +1148,8 @@ MLIRContext *context, std::optional location, ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { - if (operands[0].getType().isa() || - operands[1].getType().isa()) + if (llvm::isa(operands[0].getType()) || + llvm::isa(operands[1].getType())) inferredReturnTypes.assign({SizeType::get(context)}); else inferredReturnTypes.assign({IndexType::get(context)}); @@ -1199,7 +1205,7 @@ return nullptr; SmallVector extents; for (auto attr : adaptor.getExtents()) - extents.push_back(attr.cast().getInt()); + extents.push_back(llvm::cast(attr).getInt()); Builder builder(getContext()); return builder.getIndexTensorAttr(extents); } @@ -1215,9 +1221,8 @@ } FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) { - auto attr = getMapping() - .get(op->getName().getIdentifier()) - .dyn_cast_or_null(); + auto attr = llvm::dyn_cast_or_null( + getMapping().get(op->getName().getIdentifier())); if (!attr) return nullptr; return lookupSymbol(attr); @@ -1329,7 +1334,7 @@ if (auto constSizeOp = getDim().getDefiningOp()) return constSizeOp.getValue().getLimitedValue(); if (auto constantOp = getDim().getDefiningOp()) - return constantOp.getValue().cast().getInt(); + return llvm::cast(constantOp.getValue()).getInt(); return std::nullopt; } @@ -1349,7 +1354,7 @@ int64_t dim) { auto loc = result.location; auto dimAttr = builder.getIndexAttr(dim); - if (shape.getType().isa()) { + if (llvm::isa(shape.getType())) { Value dim = builder.create(loc, dimAttr); build(builder, result, builder.getType(), shape, dim); } else { @@ -1405,7 +1410,7 @@ return failure(); auto isShapeType = [](Type arg) { - if (arg.isa()) + if (llvm::isa(arg)) return true; return isExtentTensorType(arg); }; @@ -1414,29 +1419,29 @@ Type acc = types.front(); for (auto t : drop_begin(types)) { Type l = acc, r = t; - if (!l.isa()) + if (!llvm::isa(l)) std::swap(l, r); // Handle sizes, propagate error type if present. - if (l.isa()) { - if (r.isa()) + if (llvm::isa(l)) { + if (llvm::isa(r)) acc = l; else return emitOptionalError(location, "requires all sizes or shapes"); - } else if (l.isa()) { - if (r.isa()) + } else if (llvm::isa(l)) { + if (llvm::isa(r)) acc = r; else return emitOptionalError(location, "requires all sizes or shapes"); - } else if (l.isa()) { + } else if (llvm::isa(l)) { // Handle shapes, propagate error type if present. if (isShapeType(r)) acc = l; else return emitOptionalError(location, "requires all sizes or shapes"); } else if (isExtentTensorType(l)) { - auto rank1 = l.cast().getShape()[0]; - auto rank2 = r.cast().getShape()[0]; + auto rank1 = llvm::cast(l).getShape()[0]; + auto rank2 = llvm::cast(r).getShape()[0]; if (ShapedType::isDynamic(rank1)) acc = l; else if (ShapedType::isDynamic(rank2)) @@ -1460,13 +1465,13 @@ Type lhs = l.front(); Type rhs = r.front(); - if (!lhs.isa()) + if (!llvm::isa(lhs)) std::swap(lhs, rhs); - if (lhs.isa()) - return rhs.isa(); - if (lhs.isa()) - return rhs.isa(); + if (llvm::isa(lhs)) + return llvm::isa(rhs); + if (llvm::isa(lhs)) + return llvm::isa(rhs); if (succeeded(verifyCompatibleShapes({lhs, rhs}))) return true; @@ -1511,14 +1516,14 @@ if (!shapeOfOp) return failure(); auto rankedTensorType = - shapeOfOp.getArg().getType().dyn_cast(); + llvm::dyn_cast(shapeOfOp.getArg().getType()); if (!rankedTensorType) return failure(); int64_t rank = rankedTensorType.getRank(); - if (op.getType().isa()) { + if (llvm::isa(op.getType())) { rewriter.replaceOpWithNewOp(op.getOperation(), rank); - } else if (op.getType().isa()) { + } else if (llvm::isa(op.getType())) { rewriter.replaceOpWithNewOp(op.getOperation(), rank); } else { return failure(); @@ -1537,7 +1542,7 @@ MLIRContext *context, std::optional location, ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { - if (operands[0].getType().isa()) + if (llvm::isa(operands[0].getType())) inferredReturnTypes.assign({SizeType::get(context)}); else inferredReturnTypes.assign({IndexType::get(context)}); @@ -1563,7 +1568,7 @@ return {}; APInt product(64, 1); - for (auto value : shape.cast()) + for (auto value : llvm::cast(shape)) product *= value; Builder builder(getContext()); return builder.getIndexAttr(product.getLimitedValue()); @@ -1573,7 +1578,7 @@ MLIRContext *context, std::optional location, ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { - if (operands[0].getType().isa()) + if (llvm::isa(operands[0].getType())) inferredReturnTypes.assign({SizeType::get(context)}); else inferredReturnTypes.assign({IndexType::get(context)}); @@ -1615,9 +1620,9 @@ bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { if (l.size() != 1 || r.size() != 1) return false; - if (l.front().isa() && r.front().isa()) + if (llvm::isa(l.front()) && llvm::isa(r.front())) return true; - if (l.front().isa() && r.front().isa()) + if (llvm::isa(l.front()) && llvm::isa(r.front())) return true; return false; } @@ -1647,9 +1652,9 @@ bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { if (l.size() != 1 || r.size() != 1) return false; - if (l.front().isa() && r.front().isa()) + if (llvm::isa(l.front()) && llvm::isa(r.front())) return true; - if (l.front().isa() && r.front().isa()) + if (llvm::isa(l.front()) && llvm::isa(r.front())) return true; return false; } @@ -1674,8 +1679,8 @@ MLIRContext *context, std::optional location, ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { - if (operands[0].getType().isa() || - operands[1].getType().isa()) + if (llvm::isa(operands[0].getType()) || + llvm::isa(operands[1].getType())) inferredReturnTypes.assign({SizeType::get(context)}); else inferredReturnTypes.assign({IndexType::get(context)}); @@ -1694,7 +1699,7 @@ //===----------------------------------------------------------------------===// OpFoldResult ShapeOfOp::fold(FoldAdaptor) { - auto type = getOperand().getType().dyn_cast(); + auto type = llvm::dyn_cast(getOperand().getType()); if (!type || !type.hasStaticShape()) return nullptr; Builder builder(getContext()); @@ -1707,9 +1712,9 @@ LogicalResult matchAndRewrite(shape::ShapeOfOp op, PatternRewriter &rewriter) const override { - if (!op.getArg().getType().isa()) + if (!llvm::isa(op.getArg().getType())) return failure(); - if (op.getType().isa()) + if (llvm::isa(op.getType())) return failure(); rewriter.replaceOpWithNewOp(op.getOperation(), @@ -1732,7 +1737,7 @@ LogicalResult matchAndRewrite(tensor::CastOp op, PatternRewriter &rewriter) const override { - auto ty = op.getType().dyn_cast(); + auto ty = llvm::dyn_cast(op.getType()); if (!ty || ty.getRank() != 1) return failure(); @@ -1741,7 +1746,7 @@ return failure(); // Argument type must be ranked and must not conflict. - auto argTy = shapeOfOp.getArg().getType().dyn_cast(); + auto argTy = llvm::dyn_cast(shapeOfOp.getArg().getType()); if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank())) return failure(); @@ -1761,10 +1766,10 @@ MLIRContext *context, std::optional location, ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { - if (operands[0].getType().isa()) + if (llvm::isa(operands[0].getType())) inferredReturnTypes.assign({ShapeType::get(context)}); else { - auto shapedTy = operands[0].getType().cast(); + auto shapedTy = llvm::cast(operands[0].getType()); int64_t rank = shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamic; Type indexTy = IndexType::get(context); @@ -1783,10 +1788,11 @@ Type lhs = l.front(); Type rhs = r.front(); - if (!lhs.isa() || !rhs.isa()) + if (!llvm::isa(lhs) || + !llvm::isa(rhs)) return false; - if (lhs.isa() || rhs.isa()) + if (llvm::isa(lhs) || llvm::isa(rhs)) // Shape type is compatible with all other valid return types. return true; @@ -1819,7 +1825,8 @@ bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { if (inputs.size() != 1 || outputs.size() != 1) return false; - return inputs[0].isa() && outputs[0].isa(); + return llvm::isa(inputs[0]) && + llvm::isa(outputs[0]); } //===----------------------------------------------------------------------===// @@ -1884,16 +1891,16 @@ bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { if (inputs.size() != 1 || outputs.size() != 1) return false; - if (auto inputTensor = inputs[0].dyn_cast()) { - if (!inputTensor.getElementType().isa() || + if (auto inputTensor = llvm::dyn_cast(inputs[0])) { + if (!llvm::isa(inputTensor.getElementType()) || inputTensor.getRank() != 1) return false; - } else if (!inputs[0].isa()) { + } else if (!llvm::isa(inputs[0])) { return false; } - TensorType outputTensor = outputs[0].dyn_cast(); - return outputTensor && outputTensor.getElementType().isa(); + TensorType outputTensor = llvm::dyn_cast(outputs[0]); + return outputTensor && llvm::isa(outputTensor.getElementType()); } //===----------------------------------------------------------------------===// @@ -1911,7 +1918,7 @@ bodyBlock.addArgument(builder.getIndexType(), result.location); Type elementType; - if (auto tensorType = shape.getType().dyn_cast()) + if (auto tensorType = llvm::dyn_cast(shape.getType())) elementType = tensorType.getElementType(); else elementType = SizeType::get(builder.getContext()); @@ -1934,7 +1941,7 @@ << blockArgsCount << " arguments"; // The first block argument is the index and must always be of type `index`. - if (!block.getArgument(0).getType().isa()) + if (!llvm::isa(block.getArgument(0).getType())) return emitOpError( "argument 0 of ReduceOp body is expected to be of IndexType"); @@ -1942,12 +1949,12 @@ // `index`, depending on whether the reduce operation is applied to a shape or // to an extent tensor. Type extentTy = block.getArgument(1).getType(); - if (getShape().getType().isa()) { - if (!extentTy.isa()) + if (llvm::isa(getShape().getType())) { + if (!llvm::isa(extentTy)) return emitOpError("argument 1 of ReduceOp body is expected to be of " "SizeType if the ReduceOp operates on a ShapeType"); } else { - if (!extentTy.isa()) + if (!llvm::isa(extentTy)) return emitOpError( "argument 1 of ReduceOp body is expected to be of IndexType if the " "ReduceOp operates on an extent tensor"); diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -261,10 +261,10 @@ if (attrName == "dimLevelType") { Attribute attr; RETURN_ON_FAIL(parser.parseAttribute(attr)); - auto arrayAttr = attr.dyn_cast(); + auto arrayAttr = llvm::dyn_cast(attr); ERROR_IF(!arrayAttr, "expected an array for dimension level types") for (auto i : arrayAttr) { - auto strAttr = i.dyn_cast(); + auto strAttr = llvm::dyn_cast(i); ERROR_IF(!strAttr, "expected a string value in dimension level types") auto strVal = strAttr.getValue(); if (auto optDLT = parseDLT(strVal)) { @@ -279,25 +279,25 @@ } else if (attrName == "dimOrdering") { Attribute attr; RETURN_ON_FAIL(parser.parseAttribute(attr)) - auto affineAttr = attr.dyn_cast(); + auto affineAttr = llvm::dyn_cast(attr); ERROR_IF(!affineAttr, "expected an affine map for dimension ordering") dimOrd = affineAttr.getValue(); } else if (attrName == "higherOrdering") { Attribute attr; RETURN_ON_FAIL(parser.parseAttribute(attr)) - auto affineAttr = attr.dyn_cast(); + auto affineAttr = llvm::dyn_cast(attr); ERROR_IF(!affineAttr, "expected an affine map for higher ordering") higherOrd = affineAttr.getValue(); } else if (attrName == "posWidth") { Attribute attr; RETURN_ON_FAIL(parser.parseAttribute(attr)) - auto intAttr = attr.dyn_cast(); + auto intAttr = llvm::dyn_cast(attr); ERROR_IF(!intAttr, "expected an integral position bitwidth") posWidth = intAttr.getInt(); } else if (attrName == "crdWidth") { Attribute attr; RETURN_ON_FAIL(parser.parseAttribute(attr)) - auto intAttr = attr.dyn_cast(); + auto intAttr = llvm::dyn_cast(attr); ERROR_IF(!intAttr, "expected an integral index bitwidth") crdWidth = intAttr.getInt(); } else if (attrName == "slice") { @@ -305,7 +305,7 @@ // Dispatches to DimSliceAttr to skip mnemonic bool finished = false; while (auto attr = SparseTensorDimSliceAttr::parse(parser, nullptr)) { - auto sliceAttr = attr.cast(); + auto sliceAttr = llvm::cast(attr); slices.push_back(sliceAttr); if (parser.parseOptionalComma().failed()) { finished = true; @@ -442,9 +442,9 @@ SparseTensorEncodingAttr mlir::sparse_tensor::getSparseTensorEncoding(Type type) { - if (auto ttp = type.dyn_cast()) - return ttp.getEncoding().dyn_cast_or_null(); - if (auto mdtp = type.dyn_cast()) + if (auto ttp = llvm::dyn_cast(type)) + return llvm::dyn_cast_or_null(ttp.getEncoding()); + if (auto mdtp = llvm::dyn_cast(type)) return mdtp.getEncoding(); return nullptr; } @@ -725,12 +725,12 @@ } LogicalResult ConvertOp::verify() { - if (auto tp1 = getSource().getType().dyn_cast()) { - if (auto tp2 = getDest().getType().dyn_cast()) { + if (auto tp1 = llvm::dyn_cast(getSource().getType())) { + if (auto tp2 = llvm::dyn_cast(getDest().getType())) { if (tp1.getRank() != tp2.getRank()) return emitError("unexpected conversion mismatch in rank"); auto dstEnc = - tp2.getEncoding().dyn_cast_or_null(); + llvm::dyn_cast_or_null(tp2.getEncoding()); if (dstEnc && dstEnc.isSlice()) return emitError("cannot convert to a sparse tensor slice"); diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp --- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp +++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp @@ -41,7 +41,7 @@ options.unknownTypeConverterFn = [](Value value, Attribute memorySpace, const BufferizationOptions &options) { return getMemRefTypeWithStaticIdentityLayout( - value.getType().cast(), memorySpace); + llvm::cast(value.getType()), memorySpace); }; if (analysisOnly) { options.testAnalysisOnly = true; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h @@ -260,7 +260,7 @@ /// `IntegerType`), this also works for `RankedTensorType` and `VectorType` /// (for which it generates a constant `DenseElementsAttr` of zeros). inline Value constantZero(OpBuilder &builder, Location loc, Type tp) { - if (auto ctp = tp.dyn_cast()) { + if (auto ctp = llvm::dyn_cast(tp)) { auto zeroe = builder.getZeroAttr(ctp.getElementType()); auto zeroa = builder.getArrayAttr({zeroe, zeroe}); return builder.create(loc, tp, zeroa); @@ -271,7 +271,7 @@ /// Generates a 1-valued constant of the given type. This supports all /// the same types as `constantZero`. inline Value constantOne(OpBuilder &builder, Location loc, Type tp) { - if (auto ctp = tp.dyn_cast()) { + if (auto ctp = llvm::dyn_cast(tp)) { auto zeroe = builder.getZeroAttr(ctp.getElementType()); auto onee = getOneAttr(builder, ctp.getElementType()); auto zeroa = builder.getArrayAttr({onee, zeroe}); @@ -350,7 +350,7 @@ } inline bool isZeroRankedTensorOrScalar(Type type) { - auto rtp = type.dyn_cast(); + auto rtp = llvm::dyn_cast(type); return !rtp || rtp.getRank() == 0; } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp @@ -28,7 +28,7 @@ static std::optional> genSplitSparseConstant(OpBuilder &builder, Location loc, Value tensor) { if (auto constOp = tensor.getDefiningOp()) { - if (auto a = constOp.getValue().dyn_cast()) { + if (auto a = llvm::dyn_cast(constOp.getValue())) { auto coordinates = builder.create(loc, a.getIndices()); auto values = builder.create(loc, a.getValues()); return std::make_pair(coordinates, values); @@ -94,7 +94,7 @@ OverheadType mlir::sparse_tensor::overheadTypeEncoding(Type tp) { if (tp.isIndex()) return OverheadType::kIndex; - if (auto intTp = tp.dyn_cast()) + if (auto intTp = llvm::dyn_cast(tp)) return overheadTypeEncoding(intTp.getWidth()); llvm_unreachable("Unknown overhead type"); } @@ -169,7 +169,7 @@ return PrimaryType::kI16; if (elemTp.isInteger(8)) return PrimaryType::kI8; - if (auto complexTp = elemTp.dyn_cast()) { + if (auto complexTp = llvm::dyn_cast(elemTp)) { auto complexEltTp = complexTp.getElementType(); if (complexEltTp.isF64()) return PrimaryType::kC64; @@ -205,10 +205,10 @@ return value; // int <=> index - if (srcTp.isa() || dstTp.isa()) + if (llvm::isa(srcTp) || llvm::isa(dstTp)) return builder.create(loc, dstTp, value); - const auto srcIntTp = srcTp.dyn_cast_or_null(); + const auto srcIntTp = llvm::dyn_cast_or_null(srcTp); const bool isUnsignedCast = srcIntTp ? srcIntTp.isUnsigned() : false; return mlir::convertScalarToDtype(builder, loc, value, dstTp, isUnsignedCast); } @@ -216,7 +216,7 @@ Value sparse_tensor::genIndexLoad(OpBuilder &builder, Location loc, Value mem, Value s) { Value load = builder.create(loc, mem, s); - if (!load.getType().isa()) { + if (!llvm::isa(load.getType())) { if (load.getType().getIntOrFloatBitWidth() < 64) load = builder.create(loc, builder.getI64Type(), load); load = @@ -226,14 +226,14 @@ } mlir::TypedAttr mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) { - if (tp.isa()) + if (llvm::isa(tp)) return builder.getFloatAttr(tp, 1.0); - if (tp.isa()) + if (llvm::isa(tp)) return builder.getIndexAttr(1); - if (auto intTp = tp.dyn_cast()) + if (auto intTp = llvm::dyn_cast(tp)) return builder.getIntegerAttr(tp, APInt(intTp.getWidth(), 1)); - if (tp.isa()) { - auto shapedTp = tp.cast(); + if (llvm::isa(tp)) { + auto shapedTp = llvm::cast(tp); if (auto one = getOneAttr(builder, shapedTp.getElementType())) return DenseElementsAttr::get(shapedTp, one); } @@ -244,13 +244,13 @@ Value v) { Type tp = v.getType(); Value zero = constantZero(builder, loc, tp); - if (tp.isa()) + if (llvm::isa(tp)) return builder.create(loc, arith::CmpFPredicate::UNE, v, zero); if (tp.isIntOrIndex()) return builder.create(loc, arith::CmpIPredicate::ne, v, zero); - if (tp.dyn_cast()) + if (llvm::dyn_cast(tp)) return builder.create(loc, v, zero); llvm_unreachable("Non-numeric type"); } @@ -580,12 +580,12 @@ } // Remap value. Value val; - if (attr.getElementType().isa()) { - auto valAttr = elems[i].second.cast(); + if (llvm::isa(attr.getElementType())) { + auto valAttr = llvm::cast(elems[i].second); val = builder.create(loc, attr.getElementType(), valAttr); } else { - auto valAttr = elems[i].second.cast(); + auto valAttr = llvm::cast(elems[i].second); val = builder.create(loc, valAttr); } assert(val); @@ -597,7 +597,7 @@ size_t size, Value mem, size_t offsetIdx, Value offsetVal) { #ifndef NDEBUG - const auto memTp = mem.getType().cast(); + const auto memTp = llvm::cast(mem.getType()); assert(memTp.getRank() == 1); const DynSize memSh = memTp.getDimSize(0); assert(ShapedType::isDynamic(memSh) || memSh >= static_cast(size)); @@ -619,7 +619,7 @@ ValueRange vs, size_t offsetIdx, Value offsetVal) { #ifndef NDEBUG const size_t vsize = vs.size(); - const auto memTp = mem.getType().cast(); + const auto memTp = llvm::cast(mem.getType()); assert(memTp.getRank() == 1); const DynSize memSh = memTp.getDimSize(0); assert(ShapedType::isDynamic(memSh) || memSh >= static_cast(vsize)); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp @@ -350,7 +350,7 @@ // on positions. for (TensorId t = 0, numTensors = getNumTensors(); t < numTensors; t++) { const Value tensor = tensors[t]; - const auto rtp = tensor.getType().dyn_cast(); + const auto rtp = llvm::dyn_cast(tensor.getType()); if (!rtp) // Skips only scalar, zero ranked tensor still need to be bufferized and // (probably) filled with zeros by users. @@ -432,7 +432,7 @@ Type indexType = builder.getIndexType(); Value c0 = constantZero(builder, loc, indexType); for (TensorId t = 0, e = tensors.size(); t < e; t++) { - auto rtp = tensors[t].getType().dyn_cast(); + auto rtp = llvm::dyn_cast(tensors[t].getType()); if (!rtp) continue; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp @@ -100,7 +100,7 @@ /// completion. Needs to cast the buffer to a unranked buffer. static Value genHostRegisterMemref(OpBuilder &builder, Location loc, Value mem) { - MemRefType memTp = mem.getType().cast(); + MemRefType memTp = llvm::cast(mem.getType()); UnrankedMemRefType resTp = UnrankedMemRefType::get(memTp.getElementType(), /*memorySpace=*/0); Value cast = builder.create(loc, resTp, mem); @@ -133,7 +133,7 @@ /// that feature does not seem to be fully supported yet. static gpu::AllocOp genAllocMemRef(OpBuilder &builder, Location loc, Value mem, Value token) { - auto tp = mem.getType().cast(); + auto tp = llvm::cast(mem.getType()); auto elemTp = tp.getElementType(); auto shape = tp.getShape(); auto memTp = MemRefType::get(shape, elemTp); @@ -304,7 +304,7 @@ for (OpOperand &o : op->getOpOperands()) { Value val = o.get(); Block *block; - if (auto arg = val.dyn_cast()) + if (auto arg = llvm::dyn_cast(val)) block = arg.getOwner(); else block = val.getDefiningOp()->getBlock(); @@ -321,7 +321,7 @@ Type tp = val.getType(); if (val.getDefiningOp()) constants.push_back(val); - else if (tp.isa() || tp.isIntOrIndex()) + else if (llvm::isa(tp) || tp.isIntOrIndex()) scalars.push_back(val); else if (isa(tp)) buffers.push_back(val); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp @@ -111,9 +111,9 @@ Value metaData = builder.create(loc, structType); SpecifierStructBuilder md(metaData); if (!source) { - auto memSizeArrayType = structType.cast() - .getBody()[kMemSizePosInSpecifier] - .cast(); + auto memSizeArrayType = + llvm::cast(structType.cast() + .getBody()[kMemSizePosInSpecifier]); Value zero = constantZero(builder, loc, memSizeArrayType.getElementType()); // Fill memSizes array with zero. diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -80,7 +80,7 @@ Value idx) { idx = genCast(builder, loc, idx, builder.getIndexType()); val = genCast(builder, loc, val, - mem.getType().cast().getElementType()); + llvm::cast(mem.getType()).getElementType()); builder.create(loc, val, mem, idx); } @@ -253,7 +253,7 @@ case SparseTensorFieldKind::CrdMemRef: case SparseTensorFieldKind::ValMemRef: field = createAllocation( - builder, loc, fType.cast(), + builder, loc, llvm::cast(fType), (fKind == SparseTensorFieldKind::PosMemRef) ? posHeuristic : (fKind == SparseTensorFieldKind::CrdMemRef) ? crdHeuristic : valHeuristic, @@ -779,7 +779,7 @@ fields.reserve(desc.getNumFields()); // Memcpy on memref fields. for (auto field : desc.getMemRefFields()) { - auto memrefTp = field.getType().cast(); + auto memrefTp = llvm::cast(field.getType()); auto size = rewriter.create(loc, field, 0); auto copied = rewriter.create(loc, memrefTp, ValueRange{size}); @@ -1128,7 +1128,8 @@ auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource()); SmallVector fields; foreachFieldAndTypeInSparseTensor( - SparseTensorType(op.getResult().getType().cast()), + SparseTensorType( + llvm::cast(op.getResult().getType())), [&rewriter, &fields, srcDesc, loc](Type fTp, FieldIndex fIdx, SparseTensorFieldKind fKind, Level lvl, DimLevelType /*dlt*/) -> bool { @@ -1143,7 +1144,7 @@ // values. Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0); auto dstMem = rewriter.create( - loc, fTp.cast(), sz); + loc, llvm::cast(fTp), sz); if (fTp != srcMem.getType()) { // Converts elements type. scf::buildLoopNest( @@ -1397,7 +1398,7 @@ } assert(field); - if (auto memrefTp = field.getType().dyn_cast(); + if (auto memrefTp = llvm::dyn_cast(field.getType()); memrefTp && memrefTp.getRank() > 1) { ReassociationIndices reassociation; for (int i = 0, e = memrefTp.getRank(); i < e; i++) diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -399,7 +399,7 @@ /// (which can be either dim- or lvl-coords, depending on context). static Value genGetNextCall(OpBuilder &builder, Location loc, Value iter, Value coords, Value elemPtr) { - Type elemTp = elemPtr.getType().cast().getElementType(); + Type elemTp = llvm::cast(elemPtr.getType()).getElementType(); SmallString<10> name{"getNext", primaryTypeFunctionSuffix(elemTp)}; SmallVector params{iter, coords, elemPtr}; Type i1 = builder.getI1Type(); @@ -1045,7 +1045,7 @@ matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type resTp = op.getType(); - Type posTp = resTp.cast().getElementType(); + Type posTp = llvm::cast(resTp).getElementType(); SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)}; Value lvl = constantIndex(rewriter, op->getLoc(), op.getLevel()); replaceOpWithFuncCall(rewriter, op, name, resTp, {adaptor.getTensor(), lvl}, @@ -1064,7 +1064,7 @@ ConversionPatternRewriter &rewriter) const override { // TODO: use `SparseTensorType::getCrdType` instead. Type resType = op.getType(); - const Type crdTp = resType.cast().getElementType(); + const Type crdTp = llvm::cast(resType).getElementType(); SmallString<19> name{"sparseCoordinates", overheadTypeFunctionSuffix(crdTp)}; Location loc = op->getLoc(); @@ -1096,7 +1096,7 @@ LogicalResult matchAndRewrite(ToValuesOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto resType = op.getType().cast(); + auto resType = llvm::cast(op.getType()); rewriter.replaceOp(op, genValuesCall(rewriter, op.getLoc(), resType, adaptor.getOperands())); return success(); @@ -1113,7 +1113,8 @@ ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); // Query values array size for the actually stored values size. - Type eltType = op.getTensor().getType().cast().getElementType(); + Type eltType = + llvm::cast(op.getTensor().getType()).getElementType(); auto resTp = MemRefType::get({ShapedType::kDynamic}, eltType); Value values = genValuesCall(rewriter, loc, resTp, adaptor.getOperands()); rewriter.replaceOpWithNewOp(op, values, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -79,7 +79,7 @@ // Helper to detect chain of multiplications that do not involve x. static bool isMulChain(Value val, Value x) { - if (auto arg = val.dyn_cast()) + if (auto arg = llvm::dyn_cast(val)) return arg != x; if (auto *def = val.getDefiningOp()) { if (isa(def) || isa(def)) @@ -105,7 +105,7 @@ // Helper to detect direct yield of a zero value. static bool isZeroYield(GenericOp op) { auto yieldOp = cast(op.getRegion().front().getTerminator()); - if (auto arg = yieldOp.getOperand(0).dyn_cast()) { + if (auto arg = llvm::dyn_cast(yieldOp.getOperand(0))) { if (arg.getOwner()->getParentOp() == op) { return isZeroValue(op->getOperand(arg.getArgNumber())); } @@ -719,7 +719,7 @@ bool fromSparseConst = false; if (auto constOp = op.getSource().getDefiningOp()) { - if (constOp.getValue().dyn_cast()) { + if (llvm::dyn_cast(constOp.getValue())) { fromSparseConst = true; } } @@ -974,7 +974,7 @@ // Special-case: for each over a sparse constant uses its own rewriting // rule. if (auto constOp = input.getDefiningOp()) { - if (auto attr = constOp.getValue().dyn_cast()) { + if (auto attr = llvm::dyn_cast(constOp.getValue())) { return genForeachOnSparseConstant(op, rewriter, attr); } } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp @@ -112,7 +112,7 @@ Value mem, ArrayRef idxs, Value vmask) { VectorType vtp = vectorType(vl, mem); Value pass = constantZero(rewriter, loc, vtp); - if (idxs.back().getType().isa()) { + if (llvm::isa(idxs.back().getType())) { SmallVector scalarArgs(idxs.begin(), idxs.end()); Value indexVec = idxs.back(); scalarArgs.back() = constantIndex(rewriter, loc, 0); @@ -129,7 +129,7 @@ /// the last index, i.e. back(). static void genVectorStore(PatternRewriter &rewriter, Location loc, Value mem, ArrayRef idxs, Value vmask, Value rhs) { - if (idxs.back().getType().isa()) { + if (llvm::isa(idxs.back().getType())) { SmallVector scalarArgs(idxs.begin(), idxs.end()); Value indexVec = idxs.back(); scalarArgs.back() = constantIndex(rewriter, loc, 0); @@ -260,7 +260,7 @@ // innermost loop simply pass through as well. // Example: // a[i][j] for both i and j - if (auto arg = sub.dyn_cast()) { + if (auto arg = llvm::dyn_cast(sub)) { if (isInvariantArg(arg, block) == innermost) return false; if (codegen) @@ -298,8 +298,8 @@ Location loc = forOp.getLoc(); Value vload = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask); - Type etp = vload.getType().cast().getElementType(); - if (!etp.isa()) { + Type etp = llvm::cast(vload.getType()).getElementType(); + if (!llvm::isa(etp)) { if (etp.getIntOrFloatBitWidth() < 32) vload = rewriter.create( loc, vectorType(vl, rewriter.getI32Type()), vload); @@ -318,7 +318,7 @@ Value inv = load.getOperand(0); Value idx = load.getOperand(1); if (isInvariantValue(inv, block)) { - if (auto arg = idx.dyn_cast()) { + if (auto arg = llvm::dyn_cast(idx)) { if (isInvariantArg(arg, block) || !innermost) return false; if (codegen) @@ -369,7 +369,7 @@ if (!VectorType::isValidElementType(exp.getType())) return false; // A block argument is invariant/reduction/index. - if (auto arg = exp.dyn_cast()) { + if (auto arg = llvm::dyn_cast(exp)) { if (arg == forOp.getInductionVar()) { // We encountered a single, innermost index inside the computation, // such as a[i] = i, which must convert to [i, i+1, ...]. diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -88,9 +88,9 @@ // Overrides method from AffineExprVisitor. void visitDimExpr(AffineDimExpr expr) { if (pickedDim == nullptr || - pickIterType == iterTypes[expr.getPosition()] - .cast() - .getValue()) { + pickIterType == + llvm::cast(iterTypes[expr.getPosition()]) + .getValue()) { pickedDim = expr; } } @@ -344,7 +344,7 @@ // we can't use `getRankedTensorType`/`getSparseTensorType` here. // However, we don't need to handle `StorageSpecifierType`, so we // can use `SparseTensorType` once we guard against non-tensors. - const auto rtp = tensor.getType().dyn_cast(); + const auto rtp = llvm::dyn_cast(tensor.getType()); if (!rtp) return 0; const SparseTensorType stt(rtp); @@ -1243,7 +1243,7 @@ Location loc = op.getLoc(); if (atStart) { auto dynShape = {ShapedType::kDynamic}; - Type etp = tensor.getType().cast().getElementType(); + Type etp = llvm::cast(tensor.getType()).getElementType(); Type t1 = MemRefType::get(dynShape, etp); Type t2 = MemRefType::get(dynShape, builder.getI1Type()); Type t3 = MemRefType::get(dynShape, builder.getIndexType()); @@ -1833,7 +1833,7 @@ // required for sparse tensor slice rank reducing too. Level maxLvlRank = 0; for (auto operand : op.getOperands()) { - if (auto rtp = operand.getType().dyn_cast()) { + if (auto rtp = llvm::dyn_cast(operand.getType())) { maxLvlRank = std::max(maxLvlRank, SparseTensorType(rtp).getLvlRank()); } } diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp @@ -130,7 +130,8 @@ ArrayRef dstStaticShape, ArrayRef reassocation) { return dstStaticShape.size() > - static_cast(src.getType().cast().getRank()) + static_cast( + llvm::cast(src.getType()).getRank()) ? getExpandedOutputShapeFromInputShape( builder, loc, src, dstStaticShape, reassocation) : getCollapsedOutputShapeFromInputShape( @@ -185,7 +186,7 @@ return; } int64_t staticValue = - valueOrAttr.get().cast().getInt(); + llvm::cast(valueOrAttr.get()).getInt(); expr = expr + staticValue; }; addOpFoldResult(lowPad[dim]); diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -42,13 +42,13 @@ return op; if (complex::ConstantOp::isBuildableWith(value, type)) return builder.create(loc, type, - value.cast()); + llvm::cast(value)); return nullptr; } SmallVector tensor::getMixedSizes(OpBuilder &builder, Location loc, Value value) { - auto tensorType = value.getType().cast(); + auto tensorType = llvm::cast(value.getType()); SmallVector result; for (int64_t i = 0; i < tensorType.getRank(); ++i) { if (tensorType.isDynamicDim(i)) { @@ -63,7 +63,7 @@ FailureOr tensor::getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult) { - auto tensorType = opResult.getType().dyn_cast(); + auto tensorType = llvm::dyn_cast(opResult.getType()); assert(tensorType && "expected tensor type"); // If the op has a destination, it implements DestinationStyleOpInterface and @@ -100,7 +100,7 @@ Operation *op, SmallVector &result) { for (OpResult opResult : op->getResults()) { - if (opResult.getType().isa()) { + if (llvm::isa(opResult.getType())) { FailureOr destination = getOrCreateDestination(b, loc, opResult); if (failed(destination)) return failure(); @@ -111,8 +111,8 @@ } bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) { - if (auto rtp1 = tp1.dyn_cast()) { - if (auto rtp2 = tp2.dyn_cast()) + if (auto rtp1 = llvm::dyn_cast(tp1)) { + if (auto rtp2 = llvm::dyn_cast(tp2)) return rtp1.getShape() == rtp2.getShape() && rtp1.getElementType() == rtp2.getElementType(); return false; @@ -131,7 +131,7 @@ // Rank-reduced dims must have a static unit dimension. bool isStaticUnitSize = size.value().is() && - size.value().get().cast().getInt() == 1; + llvm::cast(size.value().get()).getInt() == 1; if (shapePos == static_cast(reducedShape.size())) { // There are no more dims in the reduced shape. All remaining sizes must @@ -220,8 +220,8 @@ /// Returns true if `target` is a ranked tensor type that preserves static /// information available in the `source` ranked tensor type. bool mlir::tensor::preservesStaticInformation(Type source, Type target) { - auto sourceType = source.dyn_cast(); - auto targetType = target.dyn_cast(); + auto sourceType = llvm::dyn_cast(source); + auto targetType = llvm::dyn_cast(target); // Requires RankedTensorType. if (!sourceType || !targetType) @@ -322,8 +322,8 @@ if (inputs.size() != 1 || outputs.size() != 1) return false; Type a = inputs.front(), b = outputs.front(); - auto aT = a.dyn_cast(); - auto bT = b.dyn_cast(); + auto aT = llvm::dyn_cast(a); + auto bT = llvm::dyn_cast(b); if (!aT || !bT) return false; @@ -380,9 +380,9 @@ return failure(); auto sourceType = - tensorCastOperand.getOperand().getType().cast(); - auto intermediateType = tensorCastOperand.getType().cast(); - auto resultType = tensorCast.getType().cast(); + llvm::cast(tensorCastOperand.getOperand().getType()); + auto intermediateType = llvm::cast(tensorCastOperand.getType()); + auto resultType = llvm::cast(tensorCast.getType()); // We can remove the intermediate cast if joining all three produces the // same result as just joining the source and result shapes. @@ -427,15 +427,15 @@ tensorCast.getOperand().getDefiningOp(); // Cannot fold cast to unranked tensor. - auto rankedResultType = tensorCast.getType().dyn_cast(); + auto rankedResultType = + llvm::dyn_cast(tensorCast.getType()); if (!rankedResultType) return failure(); if (!extractOperand || !canFoldIntoProducerOp(tensorCast) || - rankedResultType.getShape() == tensorCast.getSource() - .getType() - .cast() - .getShape()) + rankedResultType.getShape() == + llvm::cast(tensorCast.getSource().getType()) + .getShape()) return failure(); SmallVector sizes = extractOperand.getMixedSizes(); @@ -506,7 +506,7 @@ return {}; // Folding for unranked types (UnrankedTensorType) is not supported. - auto tensorType = getSource().getType().dyn_cast(); + auto tensorType = llvm::dyn_cast(getSource().getType()); if (!tensorType) return {}; @@ -527,7 +527,7 @@ // Fold dim to the operand of tensor.generate. if (auto fromElements = dyn_cast_or_null(definingOp)) { auto resultType = - fromElements.getResult().getType().cast(); + llvm::cast(fromElements.getResult().getType()); // The case where the type encodes the size of the dimension is handled // above. assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()])); @@ -751,7 +751,8 @@ if (!producer) return failure(); - auto resultType = castOp->getResult(0).getType().cast(); + auto resultType = + llvm::cast(castOp->getResult(0).getType()); ArrayRef resultShape = resultType.getShape(); SmallVector currMixedSizes = producer.getMixedSizes(); SmallVector newMixedSizes; @@ -765,7 +766,7 @@ // result dim matches. if (auto attr = currDim.dyn_cast()) { if (ShapedType::isDynamic(newDim) || - newDim != attr.cast().getInt()) { + newDim != llvm::cast(attr).getInt()) { // Something is off, the cast result shape cannot be more dynamic // than the empty tensor result shape (enforced by // `canFoldIntoProducer`). Abort for now. @@ -826,7 +827,7 @@ auto tensorCast = extract.getTensor().getDefiningOp(); if (!tensorCast) return failure(); - if (!tensorCast.getSource().getType().isa()) + if (!llvm::isa(tensorCast.getSource().getType())) return failure(); rewriter.replaceOpWithNewOp( extract, tensorCast.getSource(), extract.getIndices()); @@ -843,7 +844,7 @@ LogicalResult ExtractOp::verify() { // Verify the # indices match if we have a ranked type. - auto tensorType = getTensor().getType().cast(); + auto tensorType = llvm::cast(getTensor().getType()); if (tensorType.getRank() != static_cast(getIndices().size())) return emitOpError("incorrect number of indices for extract_element"); return success(); @@ -853,20 +854,20 @@ // If this is a splat elements attribute, simply return the value. All of // the elements of a splat attribute are the same. if (Attribute tensor = adaptor.getTensor()) - if (auto splatTensor = tensor.dyn_cast()) + if (auto splatTensor = llvm::dyn_cast(tensor)) return splatTensor.getSplatValue(); // Collect the constant indices into the tensor. SmallVector indices; for (Attribute indice : adaptor.getIndices()) { - if (!indice || !indice.isa()) + if (!indice || !llvm::isa(indice)) return {}; - indices.push_back(indice.cast().getInt()); + indices.push_back(llvm::cast(indice).getInt()); } // Fold extract(from_elements(...)). if (auto fromElementsOp = getTensor().getDefiningOp()) { - auto tensorType = fromElementsOp.getType().cast(); + auto tensorType = llvm::cast(fromElementsOp.getType()); auto rank = tensorType.getRank(); assert(static_cast(indices.size()) == tensorType.getRank() && "rank mismatch"); @@ -887,7 +888,7 @@ // If this is an elements attribute, query the value at the given indices. if (Attribute tensor = adaptor.getTensor()) { - auto elementsAttr = tensor.dyn_cast(); + auto elementsAttr = llvm::dyn_cast(tensor); if (elementsAttr && elementsAttr.isValidIndex(indices)) return elementsAttr.getValues()[indices]; } @@ -1070,7 +1071,7 @@ LogicalResult InsertOp::verify() { // Verify the # indices match if we have a ranked type. - auto destType = getDest().getType().cast(); + auto destType = llvm::cast(getDest().getType()); if (destType.getRank() != static_cast(getIndices().size())) return emitOpError("incorrect number of indices"); return success(); @@ -1080,7 +1081,7 @@ Attribute scalar = adaptor.getScalar(); Attribute dest = adaptor.getDest(); if (scalar && dest) - if (auto splatDest = dest.dyn_cast()) + if (auto splatDest = llvm::dyn_cast(dest)) if (scalar == splatDest.getSplatValue()) return dest; return {}; @@ -1113,7 +1114,7 @@ LogicalResult GenerateOp::verify() { // Ensure that the tensor type has as many dynamic dimensions as are // specified by the operands. - RankedTensorType resultTy = getType().cast(); + RankedTensorType resultTy = llvm::cast(getType()); if (getNumOperands() != resultTy.getNumDynamicDims()) return emitError("must have as many index operands as dynamic extents " "in the result type"); @@ -1122,7 +1123,7 @@ } LogicalResult GenerateOp::verifyRegions() { - RankedTensorType resultTy = getType().cast(); + RankedTensorType resultTy = llvm::cast(getType()); // Ensure that region arguments span the index space. if (!llvm::all_of(getBody().getArgumentTypes(), [](Type ty) { return ty.isIndex(); })) @@ -1150,7 +1151,7 @@ // Build and populate body. OpBuilder::InsertionGuard guard(b); Region *bodyRegion = result.regions.front().get(); - auto rank = resultTy.cast().getRank(); + auto rank = llvm::cast(resultTy).getRank(); SmallVector argumentTypes(rank, b.getIndexType()); SmallVector argumentLocs(rank, result.location); Block *bodyBlock = @@ -1170,7 +1171,7 @@ LogicalResult matchAndRewrite(GenerateOp tensorFromElements, PatternRewriter &rewriter) const final { auto resultType = - tensorFromElements.getResult().getType().cast(); + llvm::cast(tensorFromElements.getResult().getType()); if (resultType.hasStaticShape()) return failure(); @@ -1261,7 +1262,7 @@ OpFoldResult RankOp::fold(FoldAdaptor adaptor) { // Constant fold rank when the rank of the operand is known. auto type = getOperand().getType(); - auto shapedType = type.dyn_cast(); + auto shapedType = llvm::dyn_cast(type); if (shapedType && shapedType.hasRank()) return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank()); return IntegerAttr(); @@ -1284,17 +1285,17 @@ } LogicalResult ReshapeOp::verify() { - TensorType operandType = getSource().getType().cast(); - TensorType resultType = getResult().getType().cast(); + TensorType operandType = llvm::cast(getSource().getType()); + TensorType resultType = llvm::cast(getResult().getType()); if (operandType.getElementType() != resultType.getElementType()) return emitOpError("element types of source and destination tensor " "types should be the same"); int64_t shapeSize = - getShape().getType().cast().getDimSize(0); - auto resultRankedType = resultType.dyn_cast(); - auto operandRankedType = operandType.dyn_cast(); + llvm::cast(getShape().getType()).getDimSize(0); + auto resultRankedType = llvm::dyn_cast(resultType); + auto operandRankedType = llvm::dyn_cast(operandType); if (resultRankedType) { if (operandRankedType && resultRankedType.hasStaticShape() && @@ -1392,7 +1393,7 @@ ArrayRef reassociation, ArrayRef attrs) { auto resultType = inferCollapsedType( - src.getType().cast(), + llvm::cast(src.getType()), getSymbolLessAffineMaps( convertReassociationIndicesToExprs(b.getContext(), reassociation))); build(b, result, resultType, src, attrs); @@ -1488,7 +1489,7 @@ if (!fromElements) return failure(); - auto shapedTy = reshapeOp.getType().template cast(); + auto shapedTy = llvm::cast(reshapeOp.getType()); if (!shapedTy.hasStaticShape()) return failure(); @@ -1510,7 +1511,7 @@ return failure(); RankedTensorType srcType = - castOp.getSource().getType().cast(); + llvm::cast(castOp.getSource().getType()); RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType( srcType, collapseShapeOp.getReassociationMaps()); @@ -1693,9 +1694,8 @@ ArrayRef offsets, ArrayRef sizes, ArrayRef strides) { // Type inferred in the absence of rank-reducing behavior. - auto inferredType = - inferResultType(sourceRankedTensorType, offsets, sizes, strides) - .cast(); + auto inferredType = llvm::cast( + inferResultType(sourceRankedTensorType, offsets, sizes, strides)); int rankDiff = inferredType.getRank() - desiredResultRank; if (rankDiff > 0) { auto shape = inferredType.getShape(); @@ -1739,13 +1739,11 @@ dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes); dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); - auto sourceRankedTensorType = source.getType().cast(); + auto sourceRankedTensorType = llvm::cast(source.getType()); // Structuring implementation this way avoids duplication between builders. if (!resultType) { - resultType = - ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets, - staticSizes, staticStrides) - .cast(); + resultType = llvm::cast(ExtractSliceOp::inferResultType( + sourceRankedTensorType, staticOffsets, staticSizes, staticStrides)); } build(b, result, resultType, source, dynamicOffsets, dynamicSizes, dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets), @@ -1831,7 +1829,7 @@ FailureOr ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value, ArrayRef desiredShape) { - auto sourceTensorType = value.getType().dyn_cast(); + auto sourceTensorType = llvm::dyn_cast(value.getType()); assert(sourceTensorType && "not a ranked tensor type"); auto sourceShape = sourceTensorType.getShape(); if (sourceShape.equals(desiredShape)) @@ -1968,8 +1966,8 @@ return failure(); // Dynamic result shape is not supported. - auto sourceType = op.getSource().getType().cast(); - auto resultType = op.getResult().getType().cast(); + auto sourceType = llvm::cast(op.getSource().getType()); + auto resultType = llvm::cast(op.getResult().getType()); if (!sourceType.hasStaticShape() || !resultType.hasStaticShape()) return failure(); @@ -2004,13 +2002,13 @@ // New attribute constructed by the sliced values. DenseElementsAttr newAttr; - if (auto elems = attr.dyn_cast()) { + if (auto elems = llvm::dyn_cast(attr)) { SmallVector outValues; outValues.reserve(sourceType.getNumElements()); sliceElements( elems.begin(), counts, offsets, sizes, strides, &outValues); newAttr = DenseElementsAttr::get(resultType, outValues); - } else if (auto elems = attr.dyn_cast()) { + } else if (auto elems = llvm::dyn_cast(attr)) { SmallVector outValues; outValues.reserve(sourceType.getNumElements()); sliceElements( @@ -2109,7 +2107,7 @@ OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) { if (auto splat = adaptor.getSource().dyn_cast_or_null()) { - auto resultType = getResult().getType().cast(); + auto resultType = llvm::cast(getResult().getType()); if (resultType.hasStaticShape()) return splat.resizeSplat(resultType); } @@ -2124,7 +2122,7 @@ Value mlir::tensor::createCanonicalRankReducingExtractSliceOp( OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) { - auto rankedTensorType = tensor.getType().cast(); + auto rankedTensorType = llvm::cast(tensor.getType()); unsigned rank = rankedTensorType.getRank(); SmallVector offsets(rank, b.getIndexAttr(0)); SmallVector sizes = getMixedSizes(b, loc, tensor); @@ -2372,8 +2370,8 @@ auto src = (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource()); auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest()); - auto srcType = src.getType().template dyn_cast(); - auto dstType = dst.getType().template dyn_cast(); + auto srcType = llvm::dyn_cast(src.getType()); + auto dstType = llvm::dyn_cast(dst.getType()); if (!srcType || !dstType) return failure(); if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(), @@ -2482,7 +2480,7 @@ Location loc, Value tensor, Value dest) { - auto rankedTensorType = dest.getType().cast(); + auto rankedTensorType = llvm::cast(dest.getType()); unsigned rank = rankedTensorType.getRank(); SmallVector offsets(rank, b.getIndexAttr(0)); SmallVector sizes = getMixedSizes(b, loc, dest); @@ -2514,8 +2512,8 @@ } LogicalResult PadOp::verify() { - auto sourceType = getSource().getType().cast(); - auto resultType = getResult().getType().cast(); + auto sourceType = llvm::cast(getSource().getType()); + auto resultType = llvm::cast(getResult().getType()); auto expectedType = PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh()); if (!expectedType) { @@ -2542,7 +2540,7 @@ LogicalResult PadOp::verifyRegions() { auto ®ion = getRegion(); - unsigned rank = getResult().getType().cast().getRank(); + unsigned rank = llvm::cast(getResult().getType()).getRank(); Block &block = region.front(); if (block.getNumArguments() != rank) return emitError("expected the block to have ") << rank << " arguments"; @@ -2557,7 +2555,7 @@ // Ensure that the region yields an element of the right type. auto yieldOp = llvm::cast(block.getTerminator()); if (yieldOp.getValue().getType() != - getType().cast().getElementType()) + llvm::cast(getType()).getElementType()) return emitOpError("expected yield type to match shape element type"); return success(); @@ -2597,7 +2595,7 @@ Value source, ArrayRef staticLow, ArrayRef staticHigh, ValueRange low, ValueRange high, bool nofold, ArrayRef attrs) { - auto sourceType = source.getType().cast(); + auto sourceType = llvm::cast(source.getType()); if (!resultType) resultType = inferResultType(sourceType, staticLow, staticHigh); build(b, result, resultType, source, low, high, @@ -2609,7 +2607,7 @@ void PadOp::build(OpBuilder &b, OperationState &result, Type resultType, Value source, ValueRange low, ValueRange high, bool nofold, ArrayRef attrs) { - auto sourceType = source.getType().cast(); + auto sourceType = llvm::cast(source.getType()); unsigned rank = sourceType.getRank(); SmallVector staticVector(rank, ShapedType::kDynamic); build(b, result, resultType, source, staticVector, staticVector, low, high, @@ -2620,7 +2618,7 @@ Value source, ArrayRef low, ArrayRef high, bool nofold, ArrayRef attrs) { - auto sourceType = source.getType().cast(); + auto sourceType = llvm::cast(source.getType()); SmallVector dynamicLow, dynamicHigh; SmallVector staticLow, staticHigh; // staticLow and staticHigh have full information of the padding config. @@ -2632,7 +2630,7 @@ if (!resultType) { resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh); } - assert(resultType.isa()); + assert(llvm::isa(resultType)); build(b, result, resultType, source, dynamicLow, dynamicHigh, b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh), nofold ? b.getUnitAttr() : UnitAttr()); @@ -2647,7 +2645,7 @@ // Add a region and a block to yield the pad value. Region *region = result.regions[0].get(); - int sourceRank = source.getType().cast().getRank(); + int sourceRank = llvm::cast(source.getType()).getRank(); SmallVector blockArgTypes(sourceRank, b.getIndexType()); SmallVector blockArgLocs(sourceRank, result.location); @@ -2700,7 +2698,7 @@ return failure(); auto newResultType = PadOp::inferResultType( - castOp.getSource().getType().cast(), + llvm::cast(castOp.getSource().getType()), padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(), padTensorOp.getResultType().getShape()); @@ -2919,9 +2917,9 @@ LogicalResult matchAndRewrite(PadOp padTensorOp, PatternRewriter &rewriter) const override { Value input = padTensorOp.getSource(); - if (!input.getType().isa()) + if (!llvm::isa(input.getType())) return failure(); - auto inputDims = input.getType().cast().getShape(); + auto inputDims = llvm::cast(input.getType()).getShape(); auto inputRank = inputDims.size(); auto oldResultType = @@ -3240,7 +3238,7 @@ "applies to only pack or unpack operations"); int64_t destRank = op.getDestRank(); reifiedReturnShapes.resize(1, SmallVector(destRank)); - ShapedType resultType = op.getResult().getType().template cast(); + ShapedType resultType = llvm::cast(op.getResult().getType()); for (auto dim : llvm::seq(0, destRank)) { if (resultType.isDynamicDim(dim)) { reifiedReturnShapes[0][dim] = @@ -3655,8 +3653,8 @@ }; SmallVector mixedSizes; - for (auto [index, value] : - llvm::enumerate(source.getType().cast().getShape())) { + for (auto [index, value] : llvm::enumerate( + llvm::cast(source.getType()).getShape())) { if (ShapedType::isDynamic(value)) mixedSizes.push_back(b.create(loc, source, index).getResult()); else @@ -3671,7 +3669,7 @@ applyPermutationToVector(mixedSizes, outerDimsPerm); mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end()); - auto elemType = source.getType().cast().getElementType(); + auto elemType = llvm::cast(source.getType()).getElementType(); return b.create(loc, mixedSizes, elemType); } @@ -3789,7 +3787,7 @@ bool PackOp::isLikePad() { auto packedTensorType = - (*this)->getResultTypes().front().cast(); + llvm::cast((*this)->getResultTypes().front()); return isLikePadUnPad(*this, packedTensorType); } @@ -3861,7 +3859,7 @@ }; SmallVector mixedSizes; - auto srcType = source.getType().cast(); + auto srcType = llvm::cast(source.getType()); for (auto i : llvm::seq(0, srcType.getRank() - innerTileSizes.size())) { if (srcType.isDynamicDim(i)) @@ -3944,7 +3942,7 @@ // If no operand comes from a tensor::CastOp and can be folded then fail. bool hasTensorCastOperand = llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) { - if (opOperand.get().isa()) + if (llvm::isa(opOperand.get())) return false; auto castOp = opOperand.get().getDefiningOp(); return castOp && canFoldIntoConsumerOp(castOp); @@ -3961,7 +3959,7 @@ bool fold = canFoldIntoConsumerOp(tensorCastOp); newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get()); if (op.isDpsInit(&opOperand) && - !newOperands.back().getType().isa()) + !llvm::isa(newOperands.back().getType())) newResultTypes.push_back(newOperands.back().getType()); } diff --git a/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp @@ -24,8 +24,8 @@ auto castOp = cast(op); assert(value == castOp.getResult() && "invalid value"); - if (castOp.getResult().getType().isa() && - castOp.getSource().getType().isa()) { + if (llvm::isa(castOp.getResult().getType()) && + llvm::isa(castOp.getSource().getType())) { cstr.bound(value)[dim] == cstr.getExpr(castOp.getSource(), dim); } } @@ -100,7 +100,8 @@ auto rankOp = cast(op); assert(value == rankOp.getResult() && "invalid value"); - auto tensorType = rankOp.getTensor().getType().dyn_cast(); + auto tensorType = + llvm::dyn_cast(rankOp.getTensor().getType()); if (!tensorType) return; cstr.bound(value) == tensorType.getRank(); diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -60,20 +60,20 @@ // type in case the input is an unranked tensor type. // Case 1: Casting an unranked tensor - if (castOp.getSource().getType().isa()) { + if (llvm::isa(castOp.getSource().getType())) { // When casting to a ranked tensor, we cannot infer any static offset or // strides from the source. Assume fully dynamic. return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace); } // Case 2: Casting to an unranked tensor type - if (castOp.getType().isa()) { + if (llvm::isa(castOp.getType())) { return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace); } // Case 3: Ranked tensor -> ranked tensor. The offsets and strides do not // change. - auto rankedResultType = castOp.getType().cast(); + auto rankedResultType = llvm::cast(castOp.getType()); return MemRefType::get( rankedResultType.getShape(), rankedResultType.getElementType(), maybeSrcBufferType->cast().getLayout(), memorySpace); @@ -158,7 +158,7 @@ if (failed(maybeBuffer)) return failure(); Value buffer = *maybeBuffer; - auto bufferType = buffer.getType().cast(); + auto bufferType = llvm::cast(buffer.getType()); if (tensorResultType.getRank() == 0) { // 0-d collapses must go through a different op builder. @@ -383,11 +383,11 @@ SmallVector mixedOffsets = extractSliceOp.getMixedOffsets(); SmallVector mixedSizes = extractSliceOp.getMixedSizes(); SmallVector mixedStrides = extractSliceOp.getMixedStrides(); - return memref::SubViewOp::inferRankReducedResultType( - extractSliceOp.getType().getShape(), - srcMemrefType->cast(), mixedOffsets, mixedSizes, - mixedStrides) - .cast(); + return llvm::cast( + memref::SubViewOp::inferRankReducedResultType( + extractSliceOp.getType().getShape(), + srcMemrefType->cast(), mixedOffsets, mixedSizes, + mixedStrides)); } }; @@ -459,7 +459,7 @@ auto fromElementsOp = cast(op); // Should the buffer be deallocated? bool dealloc = shouldDeallocateOpResult( - fromElementsOp.getResult().cast(), options); + llvm::cast(fromElementsOp.getResult()), options); // TODO: Implement memory space for this op. if (options.defaultMemorySpace != Attribute()) @@ -467,7 +467,7 @@ // Allocate a buffer for the result. Location loc = op->getLoc(); - auto tensorType = fromElementsOp.getType().cast(); + auto tensorType = llvm::cast(fromElementsOp.getType()); auto shape = tensorType.getShape(); // TODO: Create alloc_tensor ops during TensorCopyInsertion. FailureOr tensorAlloc = @@ -540,7 +540,7 @@ ValueRange dynamicSizes, Region &generateBody) { assert(generateBody.hasOneBlock() && "expected body with single block"); - auto tensorType = tensorDestination.getType().cast(); + auto tensorType = llvm::cast(tensorDestination.getType()); assert(generateBody.getNumArguments() == tensorType.getRank() && "rank mismatch"); @@ -579,7 +579,7 @@ auto generateOp = cast(op); // Should the buffer be deallocated? bool dealloc = shouldDeallocateOpResult( - generateOp.getResult().cast(), options); + llvm::cast(generateOp.getResult()), options); // TODO: Implement memory space for this op. if (options.defaultMemorySpace != Attribute()) @@ -800,12 +800,11 @@ return failure(); // Take a subview of the destination buffer. - auto dstMemrefType = dstMemref->getType().cast(); + auto dstMemrefType = llvm::cast(dstMemref->getType()); auto subviewMemRefType = - memref::SubViewOp::inferRankReducedResultType( + llvm::cast(memref::SubViewOp::inferRankReducedResultType( insertSliceOp.getSourceType().getShape(), dstMemrefType, - mixedOffsets, mixedSizes, mixedStrides) - .cast(); + mixedOffsets, mixedSizes, mixedStrides)); Value subView = rewriter.create( loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes, mixedStrides); @@ -899,8 +898,8 @@ } // Should the buffer be deallocated? - bool dealloc = - shouldDeallocateOpResult(padOp.getResult().cast(), options); + bool dealloc = shouldDeallocateOpResult( + llvm::cast(padOp.getResult()), options); // Allocate a buffer for the padded result. FailureOr tensorAlloc = allocateTensorForShapedValue(rewriter, loc, padOp.getResult(), @@ -992,7 +991,7 @@ return failure(); auto resultMemRefType = getMemRefType( reshapeOp.getResult(), options, /*layout=*/{}, - srcBuffer->getType().cast().getMemorySpace()); + llvm::cast(srcBuffer->getType()).getMemorySpace()); replaceOpWithNewBufferizedOp( rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer); return success(); @@ -1039,14 +1038,13 @@ return failure(); // Take a subview of the destination buffer. - auto destBufferType = destBuffer->getType().cast(); + auto destBufferType = llvm::cast(destBuffer->getType()); auto subviewMemRefType = - memref::SubViewOp::inferRankReducedResultType( + llvm::cast(memref::SubViewOp::inferRankReducedResultType( parallelInsertSliceOp.getSourceType().getShape(), destBufferType, parallelInsertSliceOp.getMixedOffsets(), parallelInsertSliceOp.getMixedSizes(), - parallelInsertSliceOp.getMixedStrides()) - .cast(); + parallelInsertSliceOp.getMixedStrides())); Value subview = rewriter.create( parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer, parallelInsertSliceOp.getMixedOffsets(), diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -88,8 +88,8 @@ LogicalResult matchAndRewrite(tosa::ReshapeOp op, PatternRewriter &rewriter) const override { Value input = op.getInput1(); - ShapedType inputTy = input.getType().cast(); - ShapedType resultTy = op.getType().cast(); + ShapedType inputTy = llvm::cast(input.getType()); + ShapedType resultTy = llvm::cast(op.getType()); if (inputTy.getElementType() != resultTy.getElementType()) return rewriter.notifyMatchFailure(op, "element type does not match."); @@ -106,7 +106,7 @@ // Build new const op with correct output shape DenseElementsAttr outputAttr = inputAttr.reshape( - inputAttr.getType().cast().clone(op.getNewShape())); + llvm::cast(inputAttr.getType()).clone(op.getNewShape())); rewriter.replaceOpWithNewOp(op, resultTy, outputAttr); return success(); } @@ -198,7 +198,7 @@ } auto input = op.getInput1(); - auto inputTy = input.getType().cast(); + auto inputTy = llvm::cast(input.getType()); if (!inputTy.hasRank()) return rewriter.notifyMatchFailure(op, "Unranked input."); @@ -255,15 +255,15 @@ auto input = op.getInput1(); auto padding = op.getPadding(); - ShapedType inputTy = input.getType().cast(); + ShapedType inputTy = llvm::cast(input.getType()); Type elementTy = inputTy.getElementType(); Attribute constantAttr; - if (elementTy.isa()) { + if (llvm::isa(elementTy)) { constantAttr = rewriter.getFloatAttr(elementTy, 0.0); - } else if (elementTy.isa() && !op.getQuantizationInfo()) { + } else if (llvm::isa(elementTy) && !op.getQuantizationInfo()) { constantAttr = rewriter.getIntegerAttr(elementTy, 0); - } else if (elementTy.isa() && op.getQuantizationInfo()) { + } else if (llvm::isa(elementTy) && op.getQuantizationInfo()) { auto value = op.getQuantizationInfo()->getInputZp(); constantAttr = rewriter.getIntegerAttr(elementTy, value); } @@ -298,8 +298,8 @@ PatternRewriter &rewriter) const override { Value input = op.getInput(); Value output = op.getOutput(); - ShapedType inputType = input.getType().cast(); - ShapedType outputType = output.getType().cast(); + ShapedType inputType = llvm::cast(input.getType()); + ShapedType outputType = llvm::cast(output.getType()); if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) { return failure(); @@ -332,8 +332,7 @@ LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override { Value input = op.getInput(); - auto inputType = - op.getInput().getType().template dyn_cast(); + auto inputType = llvm::dyn_cast(op.getInput().getType()); auto inputElementType = inputType.getElementType(); if (!inputType.hasStaticShape()) { @@ -373,7 +372,7 @@ return failure(); } - if (inputElementType.isa()) { + if (llvm::isa(inputElementType)) { int64_t minClamp = op.getMinInt(); int64_t maxClamp = op.getMaxInt(); @@ -498,19 +497,19 @@ DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, RankedTensorType returnTy) { if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) { - auto lETy = lhs.getType().cast().getElementType(); - auto rETy = rhs.getType().cast().getElementType(); + auto lETy = llvm::cast(lhs.getType()).getElementType(); + auto rETy = llvm::cast(rhs.getType()).getElementType(); if (lETy != rETy) return {}; - if (lETy.isa()) { + if (llvm::isa(lETy)) { APInt l = lhs.getSplatValue(); APInt r = rhs.getSplatValue(); auto result = IntFolder()(l, r); return DenseElementsAttr::get(returnTy, result); } - if (lETy.isa()) { + if (llvm::isa(lETy)) { APFloat l = lhs.getSplatValue(); APFloat r = rhs.getSplatValue(); auto result = FloatFolder()(l, r); @@ -522,18 +521,18 @@ } static bool isSplatZero(Type elemType, DenseElementsAttr val) { - if (elemType.isa()) + if (llvm::isa(elemType)) return val && val.isSplat() && val.getSplatValue().isZero(); - if (elemType.isa()) + if (llvm::isa(elemType)) return val && val.isSplat() && val.getSplatValue().isZero(); return false; } static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) { - if (elemType.isa()) + if (llvm::isa(elemType)) return val && val.isSplat() && val.getSplatValue().isExactlyValue(1.0); - if (elemType.isa()) { + if (llvm::isa(elemType)) { const int64_t shifted = 1LL << shift; return val && val.isSplat() && val.getSplatValue().getSExtValue() == shifted; @@ -542,9 +541,9 @@ } OpFoldResult AddOp::fold(FoldAdaptor adaptor) { - auto lhsTy = getInput1().getType().dyn_cast(); - auto rhsTy = getInput2().getType().dyn_cast(); - auto resultTy = getType().dyn_cast(); + auto lhsTy = llvm::dyn_cast(getInput1().getType()); + auto rhsTy = llvm::dyn_cast(getInput2().getType()); + auto resultTy = llvm::dyn_cast(getType()); if (!lhsTy || !rhsTy || !resultTy) return {}; @@ -565,9 +564,9 @@ } OpFoldResult DivOp::fold(FoldAdaptor adaptor) { - auto lhsTy = getInput1().getType().dyn_cast(); - auto rhsTy = getInput2().getType().dyn_cast(); - auto resultTy = getType().dyn_cast(); + auto lhsTy = llvm::dyn_cast(getInput1().getType()); + auto rhsTy = llvm::dyn_cast(getInput2().getType()); + auto resultTy = llvm::dyn_cast(getType()); if (!lhsTy || !rhsTy || !resultTy) return {}; if (lhsTy != rhsTy) @@ -577,17 +576,19 @@ auto lhsAttr = adaptor.getInput1().dyn_cast_or_null(); auto rhsAttr = adaptor.getInput2().dyn_cast_or_null(); if (lhsAttr && lhsAttr.isSplat()) { - if (resultETy.isa() && lhsAttr.getSplatValue().isZero()) + if (llvm::isa(resultETy) && + lhsAttr.getSplatValue().isZero()) return lhsAttr; } if (rhsAttr && rhsAttr.isSplat()) { - if (resultETy.isa() && rhsAttr.getSplatValue().isOne()) + if (llvm::isa(resultETy) && + rhsAttr.getSplatValue().isOne()) return getInput1(); } if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat()) { - if (resultETy.isa()) { + if (llvm::isa(resultETy)) { APInt l = lhsAttr.getSplatValue(); APInt r = rhsAttr.getSplatValue(); APInt result = l.sdiv(r); @@ -602,7 +603,7 @@ DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, RankedTensorType ty, int32_t shift) { if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) { - if (ty.getElementType().isa()) { + if (llvm::isa(ty.getElementType())) { APInt l = lhs.getSplatValue(); APInt r = rhs.getSplatValue(); @@ -619,7 +620,7 @@ return DenseElementsAttr::get(ty, result); } - if (ty.getElementType().isa()) { + if (llvm::isa(ty.getElementType())) { APFloat l = lhs.getSplatValue(); APFloat r = rhs.getSplatValue(); APFloat result = l * r; @@ -634,9 +635,9 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) { auto lhs = getInput1(); auto rhs = getInput2(); - auto lhsTy = lhs.getType().dyn_cast(); - auto rhsTy = rhs.getType().dyn_cast(); - auto resultTy = getType().dyn_cast(); + auto lhsTy = llvm::dyn_cast(lhs.getType()); + auto rhsTy = llvm::dyn_cast(rhs.getType()); + auto resultTy = llvm::dyn_cast(getType()); if (!lhsTy || !rhsTy || !resultTy) return {}; @@ -644,7 +645,7 @@ auto lhsAttr = adaptor.getInput1().dyn_cast_or_null(); auto rhsAttr = adaptor.getInput2().dyn_cast_or_null(); - const int64_t shift = resultETy.isa() ? getShift() : 0; + const int64_t shift = llvm::isa(resultETy) ? getShift() : 0; if (rhsTy == resultTy) { if (isSplatZero(resultETy, lhsAttr)) return lhsAttr; @@ -662,9 +663,9 @@ } OpFoldResult SubOp::fold(FoldAdaptor adaptor) { - auto lhsTy = getInput1().getType().dyn_cast(); - auto rhsTy = getInput2().getType().dyn_cast(); - auto resultTy = getType().dyn_cast(); + auto lhsTy = llvm::dyn_cast(getInput1().getType()); + auto rhsTy = llvm::dyn_cast(getInput2().getType()); + auto resultTy = llvm::dyn_cast(getType()); if (!lhsTy || !rhsTy || !resultTy) return {}; @@ -711,7 +712,7 @@ } // namespace OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) { - auto resultTy = getType().dyn_cast(); + auto resultTy = llvm::dyn_cast(getType()); auto lhsAttr = adaptor.getInput1().dyn_cast_or_null(); auto rhsAttr = adaptor.getInput2().dyn_cast_or_null(); @@ -723,7 +724,7 @@ } OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) { - auto resultTy = getType().dyn_cast(); + auto resultTy = llvm::dyn_cast(getType()); auto lhsAttr = adaptor.getInput1().dyn_cast_or_null(); auto rhsAttr = adaptor.getInput2().dyn_cast_or_null(); @@ -736,16 +737,16 @@ } OpFoldResult EqualOp::fold(FoldAdaptor adaptor) { - auto resultTy = getType().dyn_cast(); + auto resultTy = llvm::dyn_cast(getType()); auto lhsAttr = adaptor.getInput1().dyn_cast_or_null(); auto rhsAttr = adaptor.getInput2().dyn_cast_or_null(); Value lhs = getInput1(); Value rhs = getInput2(); - auto lhsTy = lhs.getType().cast(); + auto lhsTy = llvm::cast(lhs.getType()); // If we are comparing an integer value to itself it is always true. We can // not do this with float due to float values. - if (lhsTy.getElementType().isa() && resultTy && + if (llvm::isa(lhsTy.getElementType()) && resultTy && resultTy.hasStaticShape() && lhs == rhs) { return DenseElementsAttr::get(resultTy, true); } @@ -766,41 +767,41 @@ if (!operand) return {}; - auto inTy = getInput().getType().cast(); - auto outTy = getType().cast(); + auto inTy = llvm::cast(getInput().getType()); + auto outTy = llvm::cast(getType()); auto inETy = inTy.getElementType(); auto outETy = outTy.getElementType(); if (operand.isSplat()) { - if (inETy.isa() && outETy.isa()) { + if (llvm::isa(inETy) && llvm::isa(outETy)) { bool overflow; auto splatVal = operand.getSplatValue(); - auto &semantics = outETy.cast().getFloatSemantics(); + auto &semantics = llvm::cast(outETy).getFloatSemantics(); splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven, &overflow); return SplatElementsAttr::get(outTy, splatVal); } - if (inETy.isa() && outETy.isa()) { - auto unsign = inETy.cast().isUnsignedInteger(); - APFloat splatVal(outETy.cast().getFloatSemantics()); + if (llvm::isa(inETy) && llvm::isa(outETy)) { + auto unsign = llvm::cast(inETy).isUnsignedInteger(); + APFloat splatVal(llvm::cast(outETy).getFloatSemantics()); splatVal.convertFromAPInt(operand.getSplatValue(), !unsign, llvm::RoundingMode::NearestTiesToEven); return SplatElementsAttr::get(outTy, splatVal); } - if (inETy.isa() && outETy.isa()) { - auto unsign = outETy.cast().isUnsignedInteger(); - auto intVal = - APSInt(outETy.cast().getIntOrFloatBitWidth(), unsign); + if (llvm::isa(inETy) && llvm::isa(outETy)) { + auto unsign = llvm::cast(outETy).isUnsignedInteger(); + auto intVal = APSInt( + llvm::cast(outETy).getIntOrFloatBitWidth(), unsign); auto floatVal = operand.getSplatValue(); bool exact; floatVal.convertToInteger(intVal, llvm::RoundingMode::TowardZero, &exact); return SplatElementsAttr::get(outTy, intVal); } - if (inETy.isa() && outETy.isa()) { - auto unsignIn = inETy.cast().isUnsignedInteger(); + if (llvm::isa(inETy) && llvm::isa(outETy)) { + auto unsignIn = llvm::cast(inETy).isUnsignedInteger(); bool trunc = inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth(); auto intVal = operand.getSplatValue(); @@ -842,8 +843,8 @@ #undef REDUCE_FOLDER OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) { - auto inputTy = getInput1().getType().dyn_cast(); - auto outputTy = getType().dyn_cast(); + auto inputTy = llvm::dyn_cast(getInput1().getType()); + auto outputTy = llvm::dyn_cast(getType()); if (!inputTy || !outputTy) return {}; @@ -894,8 +895,8 @@ } auto input = getInput(); - auto inputTy = input.getType().cast(); - auto resultTy = getType().cast(); + auto inputTy = llvm::cast(input.getType()); + auto resultTy = llvm::cast(getType()); if (inputTy != resultTy) return {}; @@ -904,7 +905,7 @@ OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) { auto operand = getInput(); - auto operandTy = operand.getType().cast(); + auto operandTy = llvm::cast(operand.getType()); auto axis = getAxis(); auto operandAttr = adaptor.getInput().dyn_cast_or_null(); if (operandAttr) @@ -918,8 +919,8 @@ } OpFoldResult SliceOp::fold(FoldAdaptor adaptor) { - auto inputTy = getInput().getType().dyn_cast(); - auto outputTy = getType().dyn_cast(); + auto inputTy = llvm::dyn_cast(getInput().getType()); + auto outputTy = llvm::dyn_cast(getType()); if (!inputTy || !outputTy) return {}; @@ -972,8 +973,8 @@ } OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) { - auto inputTy = getInput1().getType().cast(); - auto resultTy = getType().cast(); + auto inputTy = llvm::cast(getInput1().getType()); + auto resultTy = llvm::cast(getType()); // Transposing splat values just means reshaping. if (auto input = adaptor.getInput1().dyn_cast_or_null()) { diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -90,8 +90,9 @@ Type type, Location loc) { // Tosa dialect constants only support ElementsAttr unlike standard dialect // constant which supports all attributes. - if (value.isa()) - return builder.create(loc, type, value.cast()); + if (llvm::isa(value)) + return builder.create(loc, type, + llvm::cast(value)); return nullptr; } @@ -101,10 +102,8 @@ template static LogicalResult verifyConvOp(T op) { // All TOSA conv ops have an input() and weight(). - auto inputType = - op.getInput().getType().template dyn_cast(); - auto weightType = - op.getWeight().getType().template dyn_cast(); + auto inputType = llvm::dyn_cast(op.getInput().getType()); + auto weightType = llvm::dyn_cast(op.getWeight().getType()); // Must be ranked tensor types if (!inputType) { @@ -119,8 +118,8 @@ auto inputEType = inputType.getElementType(); auto weightEType = weightType.getElementType(); - bool inputIsQuant = !inputEType.template isa(); - bool weightIsQuant = !weightEType.template isa(); + bool inputIsQuant = !llvm::isa(inputEType); + bool weightIsQuant = !llvm::isa(weightEType); // Either both must be quantized or both unquantized. if (inputIsQuant != weightIsQuant) { @@ -143,13 +142,15 @@ } LogicalResult tosa::AvgPool2dOp::verify() { - auto inputETy = getInput().getType().cast().getElementType(); - auto resultETy = getType().cast().getElementType(); + auto inputETy = llvm::cast(getInput().getType()).getElementType(); + auto resultETy = llvm::cast(getType()).getElementType(); - if (auto quantType = inputETy.dyn_cast()) + if (auto quantType = + llvm::dyn_cast(inputETy)) inputETy = quantType.getStorageType(); - if (auto quantType = resultETy.dyn_cast()) + if (auto quantType = + llvm::dyn_cast(resultETy)) resultETy = quantType.getStorageType(); if (inputETy.isF32() && resultETy.isF32()) @@ -240,16 +241,16 @@ if (quantAttr) { result.addAttribute("quantization_info", quantAttr); - auto inputType = a.getType().dyn_cast(); + auto inputType = llvm::dyn_cast(a.getType()); assert(inputType && "Input must be a shaped tensor type!"); - auto inputQType = inputType.getElementType() - .dyn_cast(); + auto inputQType = llvm::dyn_cast( + inputType.getElementType()); assert(inputQType && "Tensor must have quantized datatype!"); unsigned inputBits = inputQType.getStorageTypeIntegralWidth(); - auto outputShapedType = outputType.dyn_cast(); + auto outputShapedType = llvm::dyn_cast(outputType); assert(outputShapedType && "Output must be a shaped type"); IntegerType accElementType; @@ -368,7 +369,7 @@ OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { ShapeAdaptor inputShape = operands.getShape(0); - IntegerAttr axis = attributes.get("axis").cast(); + IntegerAttr axis = llvm::cast(attributes.get("axis")); int32_t axisVal = axis.getValue().getSExtValue(); if (!inputShape.hasRank()) { @@ -432,7 +433,7 @@ SmallVectorImpl &inferredReturnShapes) { // Infer all dimension sizes by reducing based on inputs. int32_t axis = - attributes.get("axis").cast().getValue().getSExtValue(); + llvm::cast(attributes.get("axis")).getValue().getSExtValue(); llvm::SmallVector outputShape; bool hasRankedInput = false; for (auto operand : operands) { @@ -459,7 +460,8 @@ hasRankedInput = true; } - Type inputType = operands.getType()[0].cast().getElementType(); + Type inputType = + llvm::cast(operands.getType()[0]).getElementType(); if (!hasRankedInput) { inferredReturnShapes.push_back(ShapedTypeComponents(inputType)); return success(); @@ -738,8 +740,8 @@ } mlir::LogicalResult tosa::ReshapeOp::verify() { - ShapedType inputType = getInput1().getType().cast(); - ShapedType outputType = getType().cast(); + ShapedType inputType = llvm::cast(getInput1().getType()); + ShapedType outputType = llvm::cast(getType()); if (inputType.hasStaticShape() && outputType.hasStaticShape()) { int64_t inputElementsNum = inputType.getNumElements(); @@ -1064,9 +1066,11 @@ int64_t height = inputShape.getDimSize(1); int64_t width = inputShape.getDimSize(2); - ArrayRef kernel = attributes.get("kernel").cast(); - ArrayRef stride = attributes.get("stride").cast(); - ArrayRef pad = attributes.get("pad").cast(); + ArrayRef kernel = + llvm::cast(attributes.get("kernel")); + ArrayRef stride = + llvm::cast(attributes.get("stride")); + ArrayRef pad = llvm::cast(attributes.get("pad")); if (!ShapedType::isDynamic(height)) { int64_t padded = height + pad[0] + pad[1] - kernel[0]; @@ -1473,7 +1477,7 @@ } std::optional> ApplyScaleOp::getShapeForUnroll() { - if (auto vt = getType().dyn_cast()) + if (auto vt = llvm::dyn_cast(getType())) return llvm::to_vector<4>(vt.getShape()); return std::nullopt; } diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp @@ -169,7 +169,7 @@ return success(); } if (attribute.getName().getValue() == kTargetTagAttrName) { - if (!attribute.getValue().isa()) { + if (!llvm::isa(attribute.getValue())) { return op->emitError() << attribute.getName() << " attribute must be a string"; } @@ -177,7 +177,7 @@ } if (attribute.getName().getValue() == kArgConsumedAttrName || attribute.getName().getValue() == kArgReadOnlyAttrName) { - if (!attribute.getValue().isa()) { + if (!llvm::isa(attribute.getValue())) { return op->emitError() << attribute.getName() << " must be a unit attribute"; } diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -114,7 +114,7 @@ function_ref)> operationsFn, function_ref)> paramsFn, function_ref valuesFn) { - if (handle.getType().isa()) { + if (llvm::isa(handle.getType())) { SmallVector operations; operations.reserve(values.size()); for (transform::MappedValue value : values) { @@ -130,7 +130,8 @@ return DiagnosedSilenceableFailure::success(); } - if (handle.getType().isa()) { + if (llvm::isa( + handle.getType())) { SmallVector payloadValues; payloadValues.reserve(values.size()); for (transform::MappedValue value : values) { @@ -146,7 +147,7 @@ return DiagnosedSilenceableFailure::success(); } - assert(handle.getType().isa() && + assert(llvm::isa(handle.getType()) && "unsupported kind of block argument"); SmallVector parameters; parameters.reserve(values.size()); @@ -185,7 +186,7 @@ ArrayRef targets) { assert(value != kTopLevelValue && "attempting to reset the transformation root"); - assert(value.getType().isa() && + assert(llvm::isa(value.getType()) && "wrong handle type"); for (Operation *target : targets) { @@ -195,7 +196,7 @@ << "attempting to assign a null payload op to this transform value"; } - auto iface = value.getType().cast(); + auto iface = llvm::cast(value.getType()); DiagnosedSilenceableFailure result = iface.checkPayload(value.getLoc(), targets); if (failed(result.checkAndReport())) @@ -220,7 +221,7 @@ transform::TransformState::setPayloadValues(Value handle, ValueRange payloadValues) { assert(handle != nullptr && "attempting to set params for a null value"); - assert(handle.getType().isa() && + assert(llvm::isa(handle.getType()) && "wrong handle type"); for (Value payload : payloadValues) { @@ -230,7 +231,7 @@ "value to this transform handle"; } - auto iface = handle.getType().cast(); + auto iface = llvm::cast(handle.getType()); SmallVector payloadValueVector = llvm::to_vector(payloadValues); DiagnosedSilenceableFailure result = iface.checkPayload(handle.getLoc(), payloadValueVector); @@ -262,7 +263,7 @@ << "attempting to assign a null parameter to this transform value"; } - auto valueType = value.getType().dyn_cast(); + auto valueType = llvm::dyn_cast(value.getType()); assert(value && "cannot associate parameter with a value of non-parameter type"); DiagnosedSilenceableFailure result = @@ -497,11 +498,11 @@ Operation *definingOp; std::optional resultNo; unsigned argumentNo, blockNo, regionNo; - if (auto opResult = payloadValue.dyn_cast()) { + if (auto opResult = llvm::dyn_cast(payloadValue)) { definingOp = opResult.getOwner(); resultNo = opResult.getResultNumber(); } else { - auto arg = payloadValue.cast(); + auto arg = llvm::cast(payloadValue); definingOp = arg.getParentBlock()->getParentOp(); argumentNo = arg.getArgNumber(); blockNo = std::distance(arg.getOwner()->getParent()->begin(), @@ -602,11 +603,11 @@ }; } - if (auto opResult = payloadValue.dyn_cast()) { + if (auto opResult = llvm::dyn_cast(payloadValue)) { Operation *payloadOp = opResult.getOwner(); recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue); } else { - auto arg = payloadValue.dyn_cast(); + auto arg = llvm::dyn_cast(payloadValue); for (Operation &payloadOp : *arg.getOwner()) recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue); } @@ -642,13 +643,12 @@ }; if (llvm::any_of(effects, consumesTarget)) { FULL_LDBG("----found consume effect -> SKIP\n"); - if (target.get().getType().isa()) { + if (llvm::isa(target.get().getType())) { FULL_LDBG("----recordOpHandleInvalidation\n"); ArrayRef payloadOps = getPayloadOps(target.get()); recordOpHandleInvalidation(target, payloadOps); - } else if (target.get() - .getType() - .isa()) { + } else if (llvm::isa( + target.get().getType())) { FULL_LDBG("----recordValueHandleInvalidation\n"); recordValueHandleInvalidation(target); } else { @@ -717,7 +717,7 @@ FULL_LDBG("--handle is consumed\n"); Type operandType = operand.get().getType(); - if (operandType.isa()) { + if (llvm::isa(operandType)) { FULL_LDBG("--checkRepeatedConsumptionInOperand for Operation*\n"); DiagnosedSilenceableFailure check = checkRepeatedConsumptionInOperand( @@ -727,7 +727,7 @@ FULL_LDBG("----FAILED\n"); return check; } - } else if (operandType.isa()) { + } else if (llvm::isa(operandType)) { FULL_LDBG("--checkRepeatedConsumptionInOperand For Value\n"); DiagnosedSilenceableFailure check = checkRepeatedConsumptionInOperand( @@ -794,7 +794,7 @@ #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS for (unsigned index : consumedOperands) { Value operand = transform->getOperand(index); - if (operand.getType().isa()) { + if (llvm::isa(operand.getType())) { for (Operation *payloadOp : getPayloadOps(operand)) { llvm::append_range(origOpFlatResults, payloadOp->getResults()); #if LLVM_ENABLE_ABI_BREAKING_CHECKS @@ -808,15 +808,15 @@ } continue; } - if (operand.getType().isa()) { + if (llvm::isa(operand.getType())) { for (Value payloadValue : getPayloadValues(operand)) { - if (payloadValue.isa()) { + if (llvm::isa(payloadValue)) { origAssociatedOps.push_back(payloadValue.getDefiningOp()); continue; } llvm::append_range( origAssociatedOps, - llvm::map_range(*payloadValue.cast().getOwner(), + llvm::map_range(*llvm::cast(payloadValue).getOwner(), [](Operation &op) { return &op; })); } continue; @@ -847,9 +847,10 @@ // allows us to catch use-after-free with assertions later on. for (unsigned index : consumedOperands) { Value operand = transform->getOperand(index); - if (operand.getType().isa()) { + if (llvm::isa(operand.getType())) { forgetMapping(operand, origOpFlatResults); - } else if (operand.getType().isa()) { + } else if (llvm::isa( + operand.getType())) { forgetValueMapping(operand, origAssociatedOps); } } @@ -923,14 +924,14 @@ LogicalResult transform::TransformState::updateStateFromResults( const TransformResults &results, ResultRange opResults) { for (OpResult result : opResults) { - if (result.getType().isa()) { + if (llvm::isa(result.getType())) { assert(results.isParam(result.getResultNumber()) && "expected parameters for the parameter-typed result"); if (failed( setParams(result, results.getParams(result.getResultNumber())))) { return failure(); } - } else if (result.getType().isa()) { + } else if (llvm::isa(result.getType())) { assert(results.isValue(result.getResultNumber()) && "expected values for value-type-result"); if (failed(setPayloadValues( @@ -1137,19 +1138,19 @@ llvm::zip(partialResult, transformOp->getResults())) { if (ptr.isNull()) continue; - if (res.getType().template isa() && + if (llvm::isa(res.getType()) && !ptr.is()) { return emitDiag() << "application of " << transformOpName << " expected to produce an Operation * for result #" << res.getResultNumber(); } - if (res.getType().template isa() && + if (llvm::isa(res.getType()) && !ptr.is()) { return emitDiag() << "application of " << transformOpName << " expected to produce an Attribute for result #" << res.getResultNumber(); } - if (res.getType().template isa() && + if (llvm::isa(res.getType()) && !ptr.is()) { return emitDiag() << "application of " << transformOpName << " expected to produce a Value for result #" @@ -1182,10 +1183,10 @@ for (OpResult r : transformOp->getResults()) { unsigned position = r.getResultNumber(); - if (r.getType().isa()) { + if (llvm::isa(r.getType())) { transformResults.setParams(r, castVector(transposed[position])); - } else if (r.getType().isa()) { + } else if (llvm::isa(r.getType())) { transformResults.setValues(r, castVector(transposed[position])); } else { transformResults.set(r, castVector(transposed[position])); @@ -1202,12 +1203,13 @@ ValueRange values, const transform::TransformState &state) { for (Value operand : values) { SmallVector &mapped = mappings.emplace_back(); - if (operand.getType().isa()) { + if (llvm::isa(operand.getType())) { llvm::append_range(mapped, state.getPayloadOps(operand)); - } else if (operand.getType().isa()) { + } else if (llvm::isa( + operand.getType())) { llvm::append_range(mapped, state.getPayloadValues(operand)); } else { - assert(operand.getType().isa() && + assert(llvm::isa(operand.getType()) && "unsupported kind of transform dialect value"); llvm::append_range(mapped, state.getParams(operand)); } @@ -1220,14 +1222,15 @@ for (auto &&[terminatorOperand, result] : llvm::zip(block->getTerminator()->getOperands(), block->getParentOp()->getOpResults())) { - if (result.getType().isa()) { + if (llvm::isa(result.getType())) { results.set(result, state.getPayloadOps(terminatorOperand)); - } else if (result.getType() - .isa()) { + } else if (llvm::isa( + result.getType())) { results.setValues(result, state.getPayloadValues(terminatorOperand)); } else { - assert(result.getType().isa() && - "unhandled transform type interface"); + assert( + llvm::isa(result.getType()) && + "unhandled transform type interface"); results.setParams(result, state.getParams(terminatorOperand)); } } @@ -1291,7 +1294,8 @@ return op->emitOpError() << "expects the entry block to have at least one argument"; } - if (!body->getArgument(0).getType().isa()) { + if (!llvm::isa( + body->getArgument(0).getType())) { return op->emitOpError() << "expects the first entry block argument to be of type " "implementing TransformHandleTypeInterface"; @@ -1305,9 +1309,8 @@ } } for (BlockArgument arg : body->getArguments().drop_front()) { - if (arg.getType() - .isa()) + if (llvm::isa(arg.getType())) continue; InFlightDiagnostic diag = @@ -1344,9 +1347,8 @@ bool hasPayloadOperands = false; for (Value operand : op->getOperands()) { onlyReadsHandle(operand, effects); - if (operand.getType() - .isa()) + if (llvm::isa(operand.getType())) hasPayloadOperands = true; } if (hasPayloadOperands) @@ -1364,7 +1366,7 @@ op->getName().getStringRef()); } for (Value result : op->getResults()) { - if (result.getType().isa()) + if (llvm::isa(result.getType())) continue; return op->emitOpError() << "ParamProducerTransformOpTrait attached to this op expects " diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -440,8 +440,8 @@ return llvm::all_of( std::initializer_list{inputs.front(), outputs.front()}, [](Type ty) { - return ty - .isa(); + return llvm::isa(ty); }); } @@ -563,7 +563,8 @@ // the payload to the result. Note that we need to consume the root handle to // make sure any handles to operations inside, that could have been affected // by actions, are invalidated. - results.set(getUpdated().cast(), state.getPayloadOps(getRoot())); + results.set(llvm::cast(getUpdated()), + state.getPayloadOps(getRoot())); return DiagnosedSilenceableFailure::success(); } @@ -810,7 +811,7 @@ } for (unsigned i = 0; i < getNumResults(); ++i) - results.set(getResult(i).cast(), resultOps[i]); + results.set(llvm::cast(getResult(i)), resultOps[i]); return DiagnosedSilenceableFailure::success(); } @@ -863,7 +864,7 @@ return emitOpError() << "expects the same number of results as the " "terminator has operands"; for (Value v : yieldOp.getOperands()) - if (!v.getType().isa()) + if (!llvm::isa(v.getType())) return yieldOp->emitOpError("expects operands to have types implementing " "TransformHandleTypeInterface"); return success(); @@ -888,7 +889,7 @@ } parents.insert(parent); } - results.set(getResult().cast(), parents.getArrayRef()); + results.set(llvm::cast(getResult()), parents.getArrayRef()); return DiagnosedSilenceableFailure::success(); } @@ -902,7 +903,7 @@ int64_t resultNumber = getResultNumber(); ArrayRef payloadOps = state.getPayloadOps(getTarget()); if (payloadOps.empty()) { - results.set(getResult().cast(), {}); + results.set(llvm::cast(getResult()), {}); return DiagnosedSilenceableFailure::success(); } if (payloadOps.size() != 1) @@ -912,7 +913,7 @@ Operation *target = payloadOps.front(); if (target->getNumResults() <= resultNumber) return emitDefiniteFailure() << "result number overflow"; - results.set(getResult().cast(), + results.set(llvm::cast(getResult()), llvm::to_vector(target->getResult(resultNumber).getUsers())); return DiagnosedSilenceableFailure::success(); } @@ -926,7 +927,7 @@ transform::TransformState &state) { SmallVector definingOps; for (Value v : state.getPayloadValues(getTarget())) { - if (v.isa()) { + if (llvm::isa(v)) { DiagnosedSilenceableFailure diag = emitSilenceableError() << "cannot get defining op of block argument"; diag.attachNote(v.getLoc()) << "target value"; @@ -934,7 +935,7 @@ } definingOps.push_back(v.getDefiningOp()); } - results.set(getResult().cast(), definingOps); + results.set(llvm::cast(getResult()), definingOps); return DiagnosedSilenceableFailure::success(); } @@ -962,7 +963,7 @@ } producers.push_back(producer); } - results.set(getResult().cast(), producers); + results.set(llvm::cast(getResult()), producers); return DiagnosedSilenceableFailure::success(); } @@ -984,7 +985,7 @@ } opResults.push_back(target->getOpResult(resultNumber)); } - results.setValues(getResult().cast(), opResults); + results.setValues(llvm::cast(getResult()), opResults); return DiagnosedSilenceableFailure::success(); } @@ -1211,8 +1212,8 @@ } for (auto &&[i, param, reference] : llvm::enumerate(params, references)) { - auto intAttr = param.dyn_cast(); - auto refAttr = reference.dyn_cast(); + auto intAttr = llvm::dyn_cast(param); + auto refAttr = llvm::dyn_cast(reference); if (!intAttr || !refAttr) { return emitDefiniteFailure() << "non-integer parameter value not expected"; @@ -1295,12 +1296,12 @@ for (Value operand : getHandles()) llvm::append_range(operations, state.getPayloadOps(operand)); if (!getDeduplicate()) { - results.set(getResult().cast(), operations); + results.set(llvm::cast(getResult()), operations); return DiagnosedSilenceableFailure::success(); } SetVector uniqued(operations.begin(), operations.end()); - results.set(getResult().cast(), uniqued.getArrayRef()); + results.set(llvm::cast(getResult()), uniqued.getArrayRef()); return DiagnosedSilenceableFailure::success(); } @@ -1535,7 +1536,7 @@ // Set transform op results. for (auto &&it : llvm::enumerate(resultHandles)) - results.set(getResult(it.index()).cast(), it.value()); + results.set(llvm::cast(getResult(it.index())), it.value()); return DiagnosedSilenceableFailure::success(); } @@ -1573,7 +1574,7 @@ << "could not find pattern '" << getPatternName() << "'"; } } - results.set(getResult().cast(), targets); + results.set(llvm::cast(getResult()), targets); return DiagnosedSilenceableFailure::success(); } @@ -1594,22 +1595,23 @@ unsigned numRepetitions = state.getPayloadOps(getPattern()).size(); for (const auto &en : llvm::enumerate(getHandles())) { Value handle = en.value(); - if (handle.getType().isa()) { + if (llvm::isa(handle.getType())) { ArrayRef current = state.getPayloadOps(handle); SmallVector payload; payload.reserve(numRepetitions * current.size()); for (unsigned i = 0; i < numRepetitions; ++i) llvm::append_range(payload, current); - results.set(getReplicated()[en.index()].cast(), payload); + results.set(llvm::cast(getReplicated()[en.index()]), payload); } else { - assert(handle.getType().isa() && + assert(llvm::isa(handle.getType()) && "expected param type"); ArrayRef current = state.getParams(handle); SmallVector params; params.reserve(numRepetitions * current.size()); for (unsigned i = 0; i < numRepetitions; ++i) llvm::append_range(params, current); - results.setParams(getReplicated()[en.index()].cast(), params); + results.setParams(llvm::cast(getReplicated()[en.index()]), + params); } } return DiagnosedSilenceableFailure::success(); diff --git a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp @@ -75,7 +75,7 @@ LogicalResult transform::ParamType::verify(function_ref emitError, Type type) { - IntegerType intType = type.dyn_cast(); + IntegerType intType = llvm::dyn_cast(type); if (!intType || intType.getWidth() > 64) return emitError() << "only supports integer types with width <=64"; return success(); @@ -85,7 +85,7 @@ transform::ParamType::checkPayload(Location loc, ArrayRef payload) const { for (Attribute attr : payload) { - auto integerAttr = attr.dyn_cast(); + auto integerAttr = llvm::dyn_cast(attr); if (!integerAttr) { return emitSilenceableError(loc) << "expected parameter to be an integer attribute, got " << attr; diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -65,7 +65,7 @@ // Inspect constant dense values. We count up for bits that // are set, count down for bits that are cleared, and bail // when a mix is detected. - if (auto denseElts = c.getValue().dyn_cast()) { + if (auto denseElts = llvm::dyn_cast(c.getValue())) { int64_t val = 0; for (bool b : denseElts.getValues()) if (b && val >= 0) @@ -88,7 +88,7 @@ bool allTrue = true; bool allFalse = true; for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) { - int64_t i = maskIdx.cast().getInt(); + int64_t i = llvm::cast(maskIdx).getInt(); if (i < dimSize) allTrue = false; if (i > 0) @@ -125,7 +125,7 @@ return elementType.isIntOrIndex(); case CombiningKind::MINF: case CombiningKind::MAXF: - return elementType.isa(); + return llvm::isa(elementType); } return false; } @@ -143,7 +143,7 @@ VectorType vectorType) { int64_t elementVectorRank = 0; VectorType elementVectorType = - shapedType.getElementType().dyn_cast(); + llvm::dyn_cast(shapedType.getElementType()); if (elementVectorType) elementVectorRank += elementVectorType.getRank(); // 0-d transfers are to/from tensor/memref and vector<1xt>. @@ -190,15 +190,15 @@ if (i < rankOffset) { // For leading dimensions, if we can prove that index are different we // know we are accessing disjoint slices. - if (indexA.getValue().cast().getInt() != - indexB.getValue().cast().getInt()) + if (llvm::cast(indexA.getValue()).getInt() != + llvm::cast(indexB.getValue()).getInt()) return true; } else { // For this dimension, we slice a part of the memref we need to make sure // the intervals accessed don't overlap. int64_t distance = - std::abs(indexA.getValue().cast().getInt() - - indexB.getValue().cast().getInt()); + std::abs(llvm::cast(indexA.getValue()).getInt() - + llvm::cast(indexB.getValue()).getInt()); if (distance >= transferA.getVectorType().getDimSize(i - rankOffset)) return true; } @@ -325,7 +325,7 @@ Type inferredReturnType; for (auto it : llvm::enumerate(getSourceVectorType().getShape())) if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) { - return attr.cast().getValue() == it.index(); + return llvm::cast(attr).getValue() == it.index(); })) targetShape.push_back(it.value()); // TODO: update to also allow 0-d vectors when available. @@ -426,8 +426,9 @@ void vector::ReductionOp::build(OpBuilder &builder, OperationState &result, CombiningKind kind, Value vector, Value acc) { - build(builder, result, vector.getType().cast().getElementType(), - kind, vector, acc); + build(builder, result, + llvm::cast(vector.getType()).getElementType(), kind, vector, + acc); } LogicalResult ReductionOp::verify() { @@ -659,9 +660,8 @@ // because tests still use the old format when 'iterator_types' attribute is // represented as an array of strings. // TODO: Remove this conversion once tests are fixed. - ArrayAttr iteratorTypes = - result.attributes.get(getIteratorTypesAttrName(result.name)) - .cast(); + ArrayAttr iteratorTypes = llvm::cast( + result.attributes.get(getIteratorTypesAttrName(result.name))); SmallVector iteratorTypeAttrs; @@ -687,8 +687,8 @@ if (masksInfo.size() != 2) return parser.emitError(parser.getNameLoc(), "expected zero or exactly 2 vector mask operands"); - auto lhsType = types[0].cast(); - auto rhsType = types[1].cast(); + auto lhsType = llvm::cast(types[0]); + auto rhsType = llvm::cast(types[1]); auto maskElementType = parser.getBuilder().getI1Type(); std::array maskTypes = { VectorType::Builder(lhsType).setElementType(maskElementType), @@ -707,8 +707,7 @@ for (auto attr : (*this)->getAttrs()) { if (attr.getName() == getIteratorTypesAttrName()) { auto iteratorTypes = - attr.getValue() - .cast() + llvm::cast(attr.getValue()) .getAsValueRange(); // Convert IteratorType enums into the string representation. This is // needed, because tests still use the old format when 'iterator_types' @@ -778,12 +777,12 @@ // Verify 'expectedResultDims'. if (expectedResultDims.empty()) { // No batch or free dimension implies a scalar result. - if (resType.isa() || accType.isa()) + if (llvm::isa(resType) || llvm::isa(accType)) return op.emitOpError("invalid accumulator/result vector shape"); } else { // At least one batch or free dimension implies a vector result. - auto resVectorType = resType.dyn_cast(); - auto accVectorType = accType.dyn_cast(); + auto resVectorType = llvm::dyn_cast(resType); + auto accVectorType = llvm::dyn_cast(accType); if (!resVectorType || !accVectorType) return op.emitOpError("invalid accumulator/result vector shape"); @@ -841,7 +840,7 @@ Type accType = getAccType(); Type resType = getResultType(); - if (lhsType.getElementType().isa()) { + if (llvm::isa(lhsType.getElementType())) { if (!lhsType.getElementType().isSignlessInteger()) return emitOpError("only supports signless integer types"); } @@ -860,7 +859,7 @@ if (map.getNumSymbols() != 0) return emitOpError("expected indexing map ") << index << " to have no symbols"; - auto vectorType = getOperand(index).getType().dyn_cast(); + auto vectorType = llvm::dyn_cast(getOperand(index).getType()); unsigned rank = vectorType ? vectorType.getShape().size() : 0; // Verify that the map has the right number of inputs, outputs, and indices. // This also correctly accounts for (..) -> () for rank-0 results. @@ -896,7 +895,7 @@ return failure(); // Verify supported combining kind. - auto vectorType = resType.dyn_cast(); + auto vectorType = llvm::dyn_cast(resType); auto elementType = vectorType ? vectorType.getElementType() : resType; if (!isSupportedCombiningKind(getKind(), elementType)) return emitOpError("unsupported contraction type"); @@ -949,7 +948,7 @@ IteratorType targetIteratorType, MLIRContext *context) { std::vector> dimMap; for (const auto &it : llvm::enumerate(iteratorTypes)) { - auto iteratorType = it.value().cast().getValue(); + auto iteratorType = llvm::cast(it.value()).getValue(); if (iteratorType != targetIteratorType) continue; // Search lhs/rhs map results for 'targetExpr'. @@ -965,13 +964,13 @@ void ContractionOp::getIterationBounds( SmallVectorImpl &iterationBounds) { auto lhsShape = getLhsType().getShape(); - auto resVectorType = getResultType().dyn_cast(); + auto resVectorType = llvm::dyn_cast(getResultType()); SmallVector indexingMaps(getIndexingMapsArray()); SmallVector iterationShape; for (const auto &it : llvm::enumerate(getIteratorTypes())) { // Search lhs/rhs map results for 'targetExpr'. auto targetExpr = getAffineDimExpr(it.index(), getContext()); - auto iteratorType = it.value().cast().getValue(); + auto iteratorType = llvm::cast(it.value()).getValue(); if (iteratorType == IteratorType::reduction) { // Get reduction dim size from lhs shape (same size in rhsShape). int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr); @@ -1085,7 +1084,7 @@ void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result, Value source) { result.addOperands({source}); - result.addTypes(source.getType().cast().getElementType()); + result.addTypes(llvm::cast(source.getType()).getElementType()); } LogicalResult vector::ExtractElementOp::verify() { @@ -1116,15 +1115,15 @@ // Fold extractelement(broadcast(X)) -> X. if (auto broadcast = getVector().getDefiningOp()) - if (!broadcast.getSource().getType().isa()) + if (!llvm::isa(broadcast.getSource().getType())) return broadcast.getSource(); if (!pos || !src) return {}; - auto srcElements = src.cast().getValues(); + auto srcElements = llvm::cast(src).getValues(); - auto attr = pos.dyn_cast(); + auto attr = llvm::dyn_cast(pos); uint64_t posIdx = attr.getInt(); return srcElements[posIdx]; @@ -1155,7 +1154,7 @@ OpaqueProperties properties, RegionRange, SmallVectorImpl &inferredReturnTypes) { ExtractOp::Adaptor op(operands, attributes); - auto vectorType = op.getVector().getType().cast(); + auto vectorType = llvm::cast(op.getVector().getType()); if (static_cast(op.getPosition().size()) == vectorType.getRank()) { inferredReturnTypes.push_back(vectorType.getElementType()); } else { @@ -1170,7 +1169,7 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { // Allow extracting 1-element vectors instead of scalars. auto isCompatible = [](TypeRange l, TypeRange r) { - auto vectorType = l.front().dyn_cast(); + auto vectorType = llvm::dyn_cast(l.front()); return vectorType && vectorType.getShape().equals({1}) && vectorType.getElementType() == r.front(); }; @@ -1187,7 +1186,7 @@ return emitOpError( "expected position attribute of rank smaller than vector rank"); for (const auto &en : llvm::enumerate(positionAttr)) { - auto attr = en.value().dyn_cast(); + auto attr = llvm::dyn_cast(en.value()); if (!attr || attr.getInt() < 0 || attr.getInt() >= getSourceVectorType().getDimSize(en.index())) return emitOpError("expected position attribute #") @@ -1451,7 +1450,8 @@ if (extractOp.getType() == source.getType()) return source; auto getRank = [](Type type) { - return type.isa() ? type.cast().getRank() : 0; + return llvm::isa(type) ? llvm::cast(type).getRank() + : 0; }; // If splat or broadcast from a scalar, just return the source scalar. unsigned broadcastSrcRank = getRank(source.getType()); @@ -1462,8 +1462,8 @@ if (extractResultRank >= broadcastSrcRank) return Value(); // Check that the dimension of the result haven't been broadcasted. - auto extractVecType = extractOp.getType().dyn_cast(); - auto broadcastVecType = source.getType().dyn_cast(); + auto extractVecType = llvm::dyn_cast(extractOp.getType()); + auto broadcastVecType = llvm::dyn_cast(source.getType()); if (extractVecType && broadcastVecType && extractVecType.getShape() != broadcastVecType.getShape().take_back(extractResultRank)) @@ -1502,13 +1502,14 @@ return type.getShape().take_back(n + 1).front(); }; int64_t destinationRank = - extractOp.getType().isa() - ? extractOp.getType().cast().getRank() + llvm::isa(extractOp.getType()) + ? llvm::cast(extractOp.getType()).getRank() : 0; if (destinationRank > shapeCastOp.getSourceVectorType().getRank()) return Value(); if (destinationRank > 0) { - auto destinationType = extractOp.getResult().getType().cast(); + auto destinationType = + llvm::cast(extractOp.getResult().getType()); for (int64_t i = 0; i < destinationRank; i++) { // The lowest dimension of of the destination must match the lowest // dimension of the shapecast op source. @@ -1574,7 +1575,7 @@ sliceOffsets.pop_back(); } unsigned destinationRank = 0; - if (auto vecType = extractOp.getType().dyn_cast()) + if (auto vecType = llvm::dyn_cast(extractOp.getType())) destinationRank = vecType.getRank(); // The dimensions of the result need to be untouched by the // extractStridedSlice op. @@ -1595,8 +1596,8 @@ /// Fold extract_op fed from a chain of insertStridedSlice ops. static Value foldExtractStridedOpFromInsertChain(ExtractOp op) { - int64_t destinationRank = op.getType().isa() - ? op.getType().cast().getRank() + int64_t destinationRank = llvm::isa(op.getType()) + ? llvm::cast(op.getType()).getRank() : 0; auto insertOp = op.getVector().getDefiningOp(); while (insertOp) { @@ -1608,7 +1609,7 @@ auto extractOffsets = extractVector(op.getPosition()); if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) { - return attr.cast().getInt() != 1; + return llvm::cast(attr).getInt() != 1; })) return Value(); bool disjoint = false; @@ -1691,7 +1692,9 @@ if (extractOp.getType() == source.getType()) return failure(); auto getRank = [](Type type) { - return type.isa() ? type.cast().getRank() : 0; + return llvm::isa(type) + ? llvm::cast(type).getRank() + : 0; }; unsigned broadcastSrcRank = getRank(source.getType()); unsigned extractResultRank = getRank(extractOp.getType()); @@ -1703,7 +1706,7 @@ // Special case if broadcast src is a 0D vector. if (extractResultRank == 0) { - assert(broadcastSrcRank == 0 && source.getType().isa()); + assert(broadcastSrcRank == 0 && llvm::isa(source.getType())); rewriter.replaceOpWithNewOp(extractOp, source); return success(); } @@ -1726,11 +1729,11 @@ Attribute vectorCst; if (!matchPattern(sourceVector, m_Constant(&vectorCst))) return failure(); - auto splat = vectorCst.dyn_cast(); + auto splat = llvm::dyn_cast(vectorCst); if (!splat) return failure(); TypedAttr newAttr = splat.getSplatValue(); - if (auto vecDstType = extractOp.getType().dyn_cast()) + if (auto vecDstType = llvm::dyn_cast(extractOp.getType())) newAttr = DenseElementsAttr::get(vecDstType, newAttr); rewriter.replaceOpWithNewOp(extractOp, newAttr); return success(); @@ -1752,12 +1755,12 @@ if (!matchPattern(sourceVector, m_Constant(&vectorCst))) return failure(); - auto vecTy = sourceVector.getType().cast(); + auto vecTy = llvm::cast(sourceVector.getType()); if (vecTy.isScalable()) return failure(); // The splat case is handled by `ExtractOpSplatConstantFolder`. - auto dense = vectorCst.dyn_cast(); + auto dense = llvm::dyn_cast(vectorCst); if (!dense || dense.isSplat()) return failure(); @@ -1770,7 +1773,7 @@ auto denseValuesBegin = dense.value_begin() + elemBeginPosition; TypedAttr newAttr; - if (auto resVecTy = extractOp.getType().dyn_cast()) { + if (auto resVecTy = llvm::dyn_cast(extractOp.getType())) { SmallVector elementValues( denseValuesBegin, denseValuesBegin + resVecTy.getNumElements()); newAttr = DenseElementsAttr::get(resVecTy, elementValues); @@ -1794,7 +1797,7 @@ static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl &results) { for (auto attr : arrayAttr) - results.push_back(attr.cast().getInt()); + results.push_back(llvm::cast(attr).getInt()); } //===----------------------------------------------------------------------===// @@ -1830,7 +1833,7 @@ llvm::SetVector BroadcastOp::computeBroadcastedUnitDims() { // Scalar broadcast is without any unit dim broadcast. - auto srcVectorType = getSourceType().dyn_cast(); + auto srcVectorType = llvm::dyn_cast(getSourceType()); if (!srcVectorType) return {}; return ::computeBroadcastedUnitDims(srcVectorType.getShape(), @@ -1867,7 +1870,7 @@ Location loc = value.getLoc(); Type elementType = getElementTypeOrSelf(value.getType()); - VectorType srcVectorType = value.getType().dyn_cast(); + VectorType srcVectorType = llvm::dyn_cast(value.getType()); VectorType dstVectorType = VectorType::get(dstShape, elementType); // Step 2. If scalar -> dstShape broadcast, just do it. @@ -1952,7 +1955,7 @@ getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType)) return BroadcastableToResult::Success; // From now on, only vectors broadcast. - VectorType srcVectorType = srcType.dyn_cast(); + VectorType srcVectorType = llvm::dyn_cast(srcType); if (!srcVectorType) return BroadcastableToResult::SourceTypeNotAVector; @@ -2074,7 +2077,7 @@ int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) + (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0)); for (const auto &en : llvm::enumerate(maskAttr)) { - auto attr = en.value().dyn_cast(); + auto attr = llvm::dyn_cast(en.value()); if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize) return emitOpError("mask index #") << (en.index() + 1) << " out of range"; } @@ -2087,7 +2090,7 @@ OpaqueProperties properties, RegionRange, SmallVectorImpl &inferredReturnTypes) { ShuffleOp::Adaptor op(operands, attributes); - auto v1Type = op.getV1().getType().cast(); + auto v1Type = llvm::cast(op.getV1().getType()); auto v1Rank = v1Type.getRank(); // Construct resulting type: leading dimension matches mask // length, all trailing dimensions match the operands. @@ -2132,7 +2135,8 @@ if (!lhs || !rhs) return {}; - auto lhsType = lhs.cast().getType().cast(); + auto lhsType = + llvm::cast(lhs.cast().getType()); // Only support 1-D for now to avoid complicated n-D DenseElementsAttr // manipulation. if (lhsType.getRank() != 1) @@ -2140,8 +2144,8 @@ int64_t lhsSize = lhsType.getDimSize(0); SmallVector results; - auto lhsElements = lhs.cast().getValues(); - auto rhsElements = rhs.cast().getValues(); + auto lhsElements = llvm::cast(lhs).getValues(); + auto rhsElements = llvm::cast(rhs).getValues(); for (const auto &index : this->getMask().getAsValueRange()) { int64_t i = index.getZExtValue(); if (i >= lhsSize) { @@ -2170,7 +2174,7 @@ if (mask.size() != 1) return failure(); Type resType = VectorType::Builder(v1VectorType).setShape({1}); - if (mask[0].cast().getInt() == 0) + if (llvm::cast(mask[0]).getInt() == 0) rewriter.replaceOpWithNewOp(shuffleOp, resType, shuffleOp.getV1()); else @@ -2242,11 +2246,11 @@ if (!src || !dst || !pos) return {}; - auto dstElements = dst.cast().getValues(); + auto dstElements = llvm::cast(dst).getValues(); SmallVector results(dstElements); - auto attr = pos.dyn_cast(); + auto attr = llvm::dyn_cast(pos); uint64_t posIdx = attr.getInt(); results[posIdx] = src; @@ -2282,7 +2286,7 @@ if (positionAttr.size() > static_cast(destVectorType.getRank())) return emitOpError( "expected position attribute of rank smaller than dest vector rank"); - auto srcVectorType = getSourceType().dyn_cast(); + auto srcVectorType = llvm::dyn_cast(getSourceType()); if (srcVectorType && (static_cast(srcVectorType.getRank()) + positionAttr.size() != static_cast(destVectorType.getRank()))) @@ -2293,7 +2297,7 @@ return emitOpError( "expected position attribute rank to match the dest vector rank"); for (const auto &en : llvm::enumerate(positionAttr)) { - auto attr = en.value().dyn_cast(); + auto attr = llvm::dyn_cast(en.value()); if (!attr || attr.getInt() < 0 || attr.getInt() >= destVectorType.getDimSize(en.index())) return emitOpError("expected position attribute #") @@ -2314,7 +2318,7 @@ LogicalResult matchAndRewrite(InsertOp insertOp, PatternRewriter &rewriter) const override { - auto srcVecType = insertOp.getSourceType().dyn_cast(); + auto srcVecType = llvm::dyn_cast(insertOp.getSourceType()); if (!srcVecType || insertOp.getDestVectorType().getNumElements() != srcVecType.getNumElements()) return failure(); @@ -2372,7 +2376,7 @@ !destVector.hasOneUse()) return failure(); - auto denseDest = vectorDestCst.cast(); + auto denseDest = llvm::cast(vectorDestCst); Value sourceValue = op.getSource(); Attribute sourceCst; @@ -2387,7 +2391,7 @@ linearize(completePositions, computeStrides(destTy.getShape())); SmallVector insertedValues; - if (auto denseSource = sourceCst.dyn_cast()) + if (auto denseSource = llvm::dyn_cast(sourceCst)) llvm::append_range(insertedValues, denseSource.getValues()); else insertedValues.push_back(sourceCst); @@ -2455,7 +2459,7 @@ int64_t max, StringRef attrName, bool halfOpen = true) { for (auto attr : arrayAttr) { - auto val = attr.cast().getInt(); + auto val = llvm::cast(attr).getInt(); auto upper = max; if (!halfOpen) upper += 1; @@ -2476,8 +2480,7 @@ bool halfOpen = true, int64_t min = 0) { for (auto [index, attrDimPair] : llvm::enumerate(llvm::zip_first(arrayAttr, shape))) { - int64_t val = - std::get<0>(attrDimPair).template cast().getInt(); + int64_t val = llvm::cast(std::get<0>(attrDimPair)).getInt(); int64_t max = std::get<1>(attrDimPair); if (!halfOpen) max += 1; @@ -2501,8 +2504,8 @@ assert(arrayAttr2.size() <= shape.size()); for (auto [index, it] : llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) { - auto val1 = std::get<0>(it).template cast().getInt(); - auto val2 = std::get<1>(it).template cast().getInt(); + auto val1 = llvm::cast(std::get<0>(it)).getInt(); + auto val2 = llvm::cast(std::get<1>(it)).getInt(); int64_t max = std::get<2>(it); if (!halfOpen) max += 1; @@ -2643,7 +2646,7 @@ !destVector.hasOneUse()) return failure(); - auto denseDest = vectorDestCst.cast(); + auto denseDest = llvm::cast(vectorDestCst); TypedValue sourceValue = op.getSource(); Attribute sourceCst; @@ -2666,7 +2669,7 @@ // increasing linearized position indices. // Because the destination may have higher dimensionality then the slice, // we keep track of two overlapping sets of positions and offsets. - auto denseSlice = sourceCst.cast(); + auto denseSlice = llvm::cast(sourceCst); auto sliceValuesIt = denseSlice.value_begin(); auto newValues = llvm::to_vector(denseDest.getValues()); SmallVector currDestPosition(offsets.begin(), offsets.end()); @@ -2735,8 +2738,8 @@ if (operandsInfo.size() < 2) return parser.emitError(parser.getNameLoc(), "expected at least 2 operands"); - VectorType vLHS = tLHS.dyn_cast(); - VectorType vRHS = tRHS.dyn_cast(); + VectorType vLHS = llvm::dyn_cast(tLHS); + VectorType vRHS = llvm::dyn_cast(tRHS); if (!vLHS) return parser.emitError(parser.getNameLoc(), "expected vector type for operand #1"); @@ -2771,7 +2774,7 @@ LogicalResult OuterProductOp::verify() { Type tRHS = getOperandTypeRHS(); VectorType vLHS = getOperandVectorTypeLHS(), - vRHS = tRHS.dyn_cast(), + vRHS = llvm::dyn_cast(tRHS), vACC = getOperandVectorTypeACC(), vRES = getResultVectorType(); if (vLHS.getRank() != 1) @@ -2897,7 +2900,7 @@ shape.reserve(vectorType.getRank()); unsigned idx = 0; for (unsigned e = offsets.size(); idx < e; ++idx) - shape.push_back(sizes[idx].cast().getInt()); + shape.push_back(llvm::cast(sizes[idx]).getInt()); for (unsigned e = vectorType.getShape().size(); idx < e; ++idx) shape.push_back(vectorType.getShape()[idx]); @@ -2913,7 +2916,7 @@ auto sizesAttr = getVectorSubscriptAttr(builder, sizes); auto stridesAttr = getVectorSubscriptAttr(builder, strides); result.addTypes( - inferStridedSliceOpResultType(source.getType().cast(), + inferStridedSliceOpResultType(llvm::cast(source.getType()), offsetsAttr, sizesAttr, stridesAttr)); result.addAttribute(getOffsetsAttrStrName(), offsetsAttr); result.addAttribute(getSizesAttrStrName(), sizesAttr); @@ -2967,7 +2970,7 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) { // Helper to extract integer out of ArrayAttr. auto getElement = [](ArrayAttr array, int idx) { - return array[idx].cast().getInt(); + return llvm::cast(array[idx]).getInt(); }; ArrayAttr extractOffsets = op.getOffsets(); ArrayAttr extractStrides = op.getStrides(); @@ -3112,7 +3115,7 @@ if (!matchPattern(sourceVector, m_Constant(&vectorCst))) return failure(); - auto splat = vectorCst.dyn_cast(); + auto splat = llvm::dyn_cast(vectorCst); if (!splat) return failure(); @@ -3141,7 +3144,7 @@ return failure(); // The splat case is handled by `StridedSliceSplatConstantFolder`. - auto dense = vectorCst.dyn_cast(); + auto dense = llvm::dyn_cast(vectorCst); if (!dense || dense.isSplat()) return failure(); @@ -3149,7 +3152,7 @@ if (extractStridedSliceOp.hasNonUnitStrides()) return failure(); - auto sourceVecTy = sourceVector.getType().cast(); + auto sourceVecTy = llvm::cast(sourceVector.getType()); ArrayRef sourceShape = sourceVecTy.getShape(); SmallVector sourceStrides = computeStrides(sourceShape); @@ -3201,9 +3204,10 @@ auto broadcast = op.getVector().getDefiningOp(); if (!broadcast) return failure(); - auto srcVecType = broadcast.getSource().getType().dyn_cast(); + auto srcVecType = + llvm::dyn_cast(broadcast.getSource().getType()); unsigned srcRank = srcVecType ? srcVecType.getRank() : 0; - auto dstVecType = op.getType().cast(); + auto dstVecType = llvm::cast(op.getType()); unsigned dstRank = dstVecType.getRank(); unsigned rankDiff = dstRank - srcRank; // Check if the most inner dimensions of the source of the broadcast are the @@ -3269,7 +3273,7 @@ VectorType vectorType, Value source, ValueRange indices, AffineMapAttr permutationMapAttr, /*optional*/ ArrayAttr inBoundsAttr) { - Type elemType = source.getType().cast().getElementType(); + Type elemType = llvm::cast(source.getType()).getElementType(); Value padding = builder.create( result.location, elemType, builder.getZeroAttr(elemType)); build(builder, result, vectorType, source, indices, permutationMapAttr, @@ -3295,7 +3299,7 @@ ValueRange indices, Value padding, std::optional> inBounds) { AffineMap permutationMap = getTransferMinorIdentityMap( - source.getType().cast(), vectorType); + llvm::cast(source.getType()), vectorType); auto permutationMapAttr = AffineMapAttr::get(permutationMap); auto inBoundsAttr = (inBounds && !inBounds.value().empty()) ? builder.getBoolArrayAttr(inBounds.value()) @@ -3311,7 +3315,7 @@ VectorType vectorType, Value source, ValueRange indices, std::optional> inBounds) { - Type elemType = source.getType().cast().getElementType(); + Type elemType = llvm::cast(source.getType()).getElementType(); Value padding = builder.create( result.location, elemType, builder.getZeroAttr(elemType)); build(builder, result, vectorType, source, indices, padding, inBounds); @@ -3356,13 +3360,13 @@ "Use in_bounds instead."); } - if (!shapedType.isa()) + if (!llvm::isa(shapedType)) return op->emitOpError( "requires source to be a memref or ranked tensor type"); auto elementType = shapedType.getElementType(); DataLayout dataLayout = DataLayout::closest(op); - if (auto vectorElementType = elementType.dyn_cast()) { + if (auto vectorElementType = llvm::dyn_cast(elementType)) { // Memref or tensor has vector element type. unsigned sourceVecSize = dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) * @@ -3425,7 +3429,7 @@ << " vs inBounds of size: " << inBounds.size(); for (unsigned int i = 0; i < permutationMap.getNumResults(); ++i) if (permutationMap.getResult(i).isa() && - !inBounds.getValue()[i].cast().getValue()) + !llvm::cast(inBounds.getValue()[i]).getValue()) return op->emitOpError("requires broadcast dimensions to be in-bounds"); } @@ -3440,7 +3444,7 @@ bool elideInBounds = true; if (auto inBounds = op.in_bounds()) { for (auto attr : *inBounds) { - if (attr.template cast().getValue()) { + if (llvm::cast(attr).getValue()) { elideInBounds = false; break; } @@ -3496,10 +3500,10 @@ if (types.size() != 2) return parser.emitError(typesLoc, "requires two types"); auto indexType = builder.getIndexType(); - auto shapedType = types[0].dyn_cast(); - if (!shapedType || !shapedType.isa()) + auto shapedType = llvm::dyn_cast(types[0]); + if (!shapedType || !llvm::isa(shapedType)) return parser.emitError(typesLoc, "requires memref or ranked tensor type"); - VectorType vectorType = types[1].dyn_cast(); + VectorType vectorType = llvm::dyn_cast(types[1]); if (!vectorType) return parser.emitError(typesLoc, "requires vector type"); auto permMapAttrName = TransferReadOp::getPermutationMapAttrStrName(); @@ -3509,7 +3513,7 @@ permMap = getTransferMinorIdentityMap(shapedType, vectorType); result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap)); } else { - permMap = permMapAttr.cast().getValue(); + permMap = llvm::cast(permMapAttr).getValue(); } if (parser.resolveOperand(sourceInfo, shapedType, result.operands) || parser.resolveOperands(indexInfo, indexType, result.operands) || @@ -3517,7 +3521,7 @@ result.operands)) return failure(); if (hasMask.succeeded()) { - if (shapedType.getElementType().dyn_cast()) + if (llvm::dyn_cast(shapedType.getElementType())) return parser.emitError( maskInfo.location, "does not support masks with vector element type"); // Instead of adding the mask type as an op type, compute it based on the @@ -3554,7 +3558,8 @@ getInBounds() ? *getInBounds() : ArrayAttr()))) return failure(); - if (auto sourceVectorElementType = sourceElementType.dyn_cast()) { + if (auto sourceVectorElementType = + llvm::dyn_cast(sourceElementType)) { // Source has vector element type. // Check that 'sourceVectorElementType' and 'paddingType' types match. if (sourceVectorElementType != paddingType) @@ -3647,7 +3652,7 @@ /// %v0 /// ``` static Value foldRAW(TransferReadOp readOp) { - if (!readOp.getShapedType().isa()) + if (!llvm::isa(readOp.getShapedType())) return {}; auto defWrite = readOp.getSource().getDefiningOp(); while (defWrite) { @@ -3682,7 +3687,7 @@ void TransferReadOp::getEffects( SmallVectorImpl> &effects) { - if (getShapedType().isa()) + if (llvm::isa(getShapedType())) effects.emplace_back(MemoryEffects::Read::get(), getSource(), SideEffects::DefaultResource::get()); } @@ -3818,7 +3823,7 @@ LogicalResult matchAndRewrite(TransferReadOp readOp, PatternRewriter &rewriter) const override { if (readOp.hasOutOfBoundsDim() || - !readOp.getShapedType().isa()) + !llvm::isa(readOp.getShapedType())) return failure(); auto defWrite = readOp.getSource().getDefiningOp(); if (!defWrite) @@ -3889,7 +3894,7 @@ AffineMapAttr permutationMapAttr, /*optional*/ Value mask, /*optional*/ ArrayAttr inBoundsAttr) { - Type resultType = dest.getType().dyn_cast(); + Type resultType = llvm::dyn_cast(dest.getType()); build(builder, result, resultType, vector, dest, indices, permutationMapAttr, mask, inBoundsAttr); } @@ -3922,9 +3927,9 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result, Value vector, Value dest, ValueRange indices, std::optional> inBounds) { - auto vectorType = vector.getType().cast(); + auto vectorType = llvm::cast(vector.getType()); AffineMap permutationMap = getTransferMinorIdentityMap( - dest.getType().cast(), vectorType); + llvm::cast(dest.getType()), vectorType); build(builder, result, vector, dest, indices, permutationMap, inBounds); } @@ -3949,11 +3954,11 @@ if (types.size() != 2) return parser.emitError(typesLoc, "requires two types"); auto indexType = builder.getIndexType(); - VectorType vectorType = types[0].dyn_cast(); + VectorType vectorType = llvm::dyn_cast(types[0]); if (!vectorType) return parser.emitError(typesLoc, "requires vector type"); - ShapedType shapedType = types[1].dyn_cast(); - if (!shapedType || !shapedType.isa()) + ShapedType shapedType = llvm::dyn_cast(types[1]); + if (!shapedType || !llvm::isa(shapedType)) return parser.emitError(typesLoc, "requires memref or ranked tensor type"); auto permMapAttrName = TransferWriteOp::getPermutationMapAttrStrName(); auto permMapAttr = result.attributes.get(permMapAttrName); @@ -3962,14 +3967,14 @@ permMap = getTransferMinorIdentityMap(shapedType, vectorType); result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap)); } else { - permMap = permMapAttr.cast().getValue(); + permMap = llvm::cast(permMapAttr).getValue(); } if (parser.resolveOperand(vectorInfo, vectorType, result.operands) || parser.resolveOperand(sourceInfo, shapedType, result.operands) || parser.resolveOperands(indexInfo, indexType, result.operands)) return failure(); if (hasMask.succeeded()) { - if (shapedType.getElementType().dyn_cast()) + if (llvm::dyn_cast(shapedType.getElementType())) return parser.emitError( maskInfo.location, "does not support masks with vector element type"); auto maskType = inferTransferOpMaskType(vectorType, permMap); @@ -3980,7 +3985,7 @@ builder.getDenseI32ArrayAttr( {1, 1, static_cast(indexInfo.size()), static_cast(hasMask.succeeded())})); - return failure(shapedType.isa() && + return failure(llvm::isa(shapedType) && parser.addTypeToList(shapedType, result.types)); } @@ -4052,7 +4057,7 @@ if (write.getTransferRank() == 0) return failure(); auto rankedTensorType = - write.getSource().getType().dyn_cast(); + llvm::dyn_cast(write.getSource().getType()); // If not operating on tensors, bail. if (!rankedTensorType) return failure(); @@ -4119,7 +4124,7 @@ /// ``` static LogicalResult foldWAR(TransferWriteOp write, SmallVectorImpl &results) { - if (!write.getSource().getType().isa()) + if (!llvm::isa(write.getSource().getType())) return failure(); auto read = write.getVector().getDefiningOp(); if (!read) @@ -4149,7 +4154,7 @@ void TransferWriteOp::getEffects( SmallVectorImpl> &effects) { - if (getShapedType().isa()) + if (llvm::isa(getShapedType())) effects.emplace_back(MemoryEffects::Write::get(), getSource(), SideEffects::DefaultResource::get()); } @@ -4184,7 +4189,7 @@ using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TransferWriteOp writeOp, PatternRewriter &rewriter) const override { - if (!writeOp.getShapedType().isa()) + if (!llvm::isa(writeOp.getShapedType())) return failure(); vector::TransferWriteOp writeToModify = writeOp; @@ -4439,7 +4444,7 @@ // Checks for vector memrefs. Type memElemTy = memRefTy.getElementType(); - if (auto memVecTy = memElemTy.dyn_cast()) { + if (auto memVecTy = llvm::dyn_cast(memElemTy)) { if (memVecTy != resVecTy) return emitOpError("base memref and result vector types should match"); memElemTy = memVecTy.getElementType(); @@ -4471,7 +4476,7 @@ // Checks for vector memrefs. Type memElemTy = memRefTy.getElementType(); - if (auto memVecTy = memElemTy.dyn_cast()) { + if (auto memVecTy = llvm::dyn_cast(memElemTy)) { if (memVecTy != valueVecTy) return emitOpError( "base memref and valueToStore vector types should match"); @@ -4604,7 +4609,7 @@ VectorType resVType = getVectorType(); ShapedType baseType = getBaseType(); - if (!baseType.isa()) + if (!llvm::isa(baseType)) return emitOpError("requires base to be a memref or ranked tensor type"); if (resVType.getElementType() != baseType.getElementType()) @@ -4864,8 +4869,10 @@ } LogicalResult ShapeCastOp::verify() { - auto sourceVectorType = getSource().getType().dyn_cast_or_null(); - auto resultVectorType = getResult().getType().dyn_cast_or_null(); + auto sourceVectorType = + llvm::dyn_cast_or_null(getSource().getType()); + auto resultVectorType = + llvm::dyn_cast_or_null(getResult().getType()); // Check if source/result are of vector type. if (sourceVectorType && resultVectorType) @@ -4885,8 +4892,8 @@ return otherOp.getSource(); // Only allows valid transitive folding. - VectorType srcType = otherOp.getSource().getType().cast(); - VectorType resultType = getResult().getType().cast(); + VectorType srcType = llvm::cast(otherOp.getSource().getType()); + VectorType resultType = llvm::cast(getResult().getType()); if (srcType.getRank() < resultType.getRank()) { if (!isValidShapeCast(srcType.getShape(), resultType.getShape())) return {}; @@ -4923,11 +4930,11 @@ if (!constantOp) return failure(); // Only handle splat for now. - auto dense = constantOp.getValue().dyn_cast(); + auto dense = llvm::dyn_cast(constantOp.getValue()); if (!dense) return failure(); auto newAttr = - DenseElementsAttr::get(shapeCastOp.getType().cast(), + DenseElementsAttr::get(llvm::cast(shapeCastOp.getType()), dense.getSplatValue()); rewriter.replaceOpWithNewOp(shapeCastOp, newAttr); return success(); @@ -4950,7 +4957,7 @@ return failure(); auto broadcastSourceVectorType = - broadcastOp.getSourceType().dyn_cast(); + llvm::dyn_cast(broadcastOp.getSourceType()); auto broadcastSourceShape = broadcastSourceVectorType ? broadcastSourceVectorType.getShape() : ArrayRef{}; @@ -5029,7 +5036,7 @@ Type srcElemType = getSourceVectorType().getElementType(); Type dstElemType = getResultVectorType().getElementType(); - if (auto floatPack = sourceConstant.dyn_cast()) { + if (auto floatPack = llvm::dyn_cast(sourceConstant)) { if (floatPack.isSplat()) { auto splat = floatPack.getSplatValue(); @@ -5046,11 +5053,11 @@ } } - if (auto intPack = sourceConstant.dyn_cast()) { + if (auto intPack = llvm::dyn_cast(sourceConstant)) { if (intPack.isSplat()) { auto splat = intPack.getSplatValue(); - if (dstElemType.isa()) { + if (llvm::isa(dstElemType)) { uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth(); uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth(); @@ -5075,7 +5082,7 @@ //===----------------------------------------------------------------------===// static SmallVector extractShape(MemRefType memRefType) { - auto vectorType = memRefType.getElementType().dyn_cast(); + auto vectorType = llvm::dyn_cast(memRefType.getElementType()); SmallVector res(memRefType.getShape().begin(), memRefType.getShape().end()); if (vectorType) @@ -5088,7 +5095,7 @@ void TypeCastOp::build(OpBuilder &builder, OperationState &result, Value source) { result.addOperands(source); - MemRefType memRefType = source.getType().cast(); + MemRefType memRefType = llvm::cast(source.getType()); VectorType vectorType = VectorType::get(extractShape(memRefType), getElementTypeOrSelf(getElementTypeOrSelf(memRefType))); @@ -5126,7 +5133,7 @@ void vector::TransposeOp::build(OpBuilder &builder, OperationState &result, Value vector, ArrayRef transp) { - VectorType vt = vector.getType().cast(); + VectorType vt = llvm::cast(vector.getType()); SmallVector transposedShape(vt.getRank()); for (unsigned i = 0; i < transp.size(); ++i) transposedShape[i] = vt.getShape()[transp[i]]; @@ -5170,7 +5177,7 @@ return emitOpError("transposition length mismatch: ") << size; SmallVector seen(rank, false); for (const auto &ta : llvm::enumerate(transpAttr)) { - int64_t i = ta.value().cast().getInt(); + int64_t i = llvm::cast(ta.value()).getInt(); if (i < 0 || i >= rank) return emitOpError("transposition index out of range: ") << i; if (seen[i]) @@ -5239,7 +5246,7 @@ if (!bcastOp) return failure(); - auto srcVectorType = bcastOp.getSourceType().dyn_cast(); + auto srcVectorType = llvm::dyn_cast(bcastOp.getSourceType()); if (!srcVectorType || srcVectorType.getNumElements() == 1) { rewriter.replaceOpWithNewOp( transposeOp, transposeOp.getResultVectorType(), bcastOp.getSource()); @@ -5324,12 +5331,12 @@ //===----------------------------------------------------------------------===// LogicalResult ConstantMaskOp::verify() { - auto resultType = getResult().getType().cast(); + auto resultType = llvm::cast(getResult().getType()); // Check the corner case of 0-D vectors first. if (resultType.getRank() == 0) { if (getMaskDimSizes().size() != 1) return emitError("array attr must have length 1 for 0-D vectors"); - auto dim = getMaskDimSizes()[0].cast().getInt(); + auto dim = llvm::cast(getMaskDimSizes()[0]).getInt(); if (dim != 0 && dim != 1) return emitError("mask dim size must be either 0 or 1 for 0-D vectors"); return success(); @@ -5344,7 +5351,7 @@ auto resultShape = resultType.getShape(); SmallVector maskDimSizes; for (const auto &it : llvm::enumerate(getMaskDimSizes())) { - int64_t attrValue = it.value().cast().getInt(); + int64_t attrValue = llvm::cast(it.value()).getInt(); if (attrValue < 0 || attrValue > resultShape[it.index()]) return emitOpError( "array attr of size out of bounds of vector result dimension size"); @@ -5363,7 +5370,7 @@ // `vector.constant_mask`. In the future, a convention could be established // to decide if a specific dimension value could be considered as "all set". if (resultType.isScalable() && - getMaskDimSizes()[0].cast().getInt() != 0) + llvm::cast(getMaskDimSizes()[0]).getInt() != 0) return emitOpError("expected mask dim sizes for scalable masks to be 0"); return success(); } @@ -5381,14 +5388,14 @@ } LogicalResult CreateMaskOp::verify() { - auto vectorType = getResult().getType().cast(); + auto vectorType = llvm::cast(getResult().getType()); // Verify that an operand was specified for each result vector each dimension. if (vectorType.getRank() == 0) { if (getNumOperands() != 1) return emitOpError( "must specify exactly one operand for 0-D create_mask"); } else if (getNumOperands() != - getResult().getType().cast().getRank()) { + llvm::cast(getResult().getType()).getRank()) { return emitOpError( "must specify an operand for each result vector dimension"); } @@ -5413,7 +5420,7 @@ // CreateMaskOp for scalable vectors can be folded only if all dimensions // are negative or zero. - if (auto vType = createMaskOp.getType().dyn_cast()) { + if (auto vType = llvm::dyn_cast(createMaskOp.getType())) { if (vType.isScalable()) for (auto opDim : createMaskOp.getOperands()) { APInt intVal; @@ -5615,7 +5622,7 @@ "expects result type to match maskable operation result type"); if (llvm::count_if(maskableOp->getResultTypes(), - [](Type t) { return t.isa(); }) > 1) + [](Type t) { return llvm::isa(t); }) > 1) return emitOpError("multiple vector results not supported"); // Mask checks. @@ -5759,7 +5766,7 @@ SmallVector coreAttr = {getWarpSizeAttrName()}; auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName()); - p << "[" << warpSizeAttr.cast().getInt() << "]"; + p << "[" << llvm::cast(warpSizeAttr).getInt() << "]"; if (!getArgs().empty()) p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")"; @@ -5872,8 +5879,8 @@ // If the types matches there is no distribution. if (expanded == distributed) return success(); - auto expandedVecType = expanded.dyn_cast(); - auto distributedVecType = distributed.dyn_cast(); + auto expandedVecType = llvm::dyn_cast(expanded); + auto distributedVecType = llvm::dyn_cast(distributed); if (!expandedVecType || !distributedVecType) return op->emitOpError("expected vector type for distributed operands."); if (expandedVecType.getRank() != distributedVecType.getRank() || @@ -5940,7 +5947,7 @@ case CombiningKind::ADD: if (t1.isIntOrIndex() && tAcc.isIntOrIndex()) result = b.createOrFold(loc, v1, acc); - else if (t1.isa() && tAcc.isa()) + else if (llvm::isa(t1) && llvm::isa(tAcc)) result = b.createOrFold(loc, v1, acc); else llvm_unreachable("invalid value types for ADD reduction"); @@ -5950,12 +5957,12 @@ result = b.createOrFold(loc, v1, acc); break; case CombiningKind::MAXF: - assert(t1.isa() && tAcc.isa() && + assert(llvm::isa(t1) && llvm::isa(tAcc) && "expected float values"); result = b.createOrFold(loc, v1, acc); break; case CombiningKind::MINF: - assert(t1.isa() && tAcc.isa() && + assert(llvm::isa(t1) && llvm::isa(tAcc) && "expected float values"); result = b.createOrFold(loc, v1, acc); break; @@ -5978,7 +5985,7 @@ case CombiningKind::MUL: if (t1.isIntOrIndex() && tAcc.isIntOrIndex()) result = b.createOrFold(loc, v1, acc); - else if (t1.isa() && tAcc.isa()) + else if (llvm::isa(t1) && llvm::isa(tAcc)) result = b.createOrFold(loc, v1, acc); else llvm_unreachable("invalid value types for MUL reduction"); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp @@ -49,7 +49,7 @@ PatternRewriter &rewriter) const override { auto loc = op.getLoc(); VectorType dstType = op.getResultVectorType(); - VectorType srcType = op.getSourceType().dyn_cast(); + VectorType srcType = llvm::dyn_cast(op.getSourceType()); Type eltType = dstType.getElementType(); // Scalar to any vector can use splat. diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp @@ -96,9 +96,9 @@ return rewriter.create(loc, lowType, val, posAttr); } // Unroll leading dimensions. - VectorType vType = lowType.cast(); + VectorType vType = llvm::cast(lowType); Type resType = VectorType::Builder(type).dropDim(index); - auto resVectorType = resType.cast(); + auto resVectorType = llvm::cast(resType); Value result = rewriter.create( loc, resVectorType, rewriter.getZeroAttr(resVectorType)); for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) { @@ -126,7 +126,7 @@ } // Unroll leading dimensions. Type lowType = VectorType::Builder(type).dropDim(0); - VectorType vType = lowType.cast(); + VectorType vType = llvm::cast(lowType); Type insType = VectorType::Builder(vType).dropDim(0); for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) { auto posAttr = rewriter.getI64ArrayAttr(d); @@ -160,7 +160,8 @@ // Only valid for integer types. return std::nullopt; // Special case for fused multiply-add. - if (acc && acc.getType().isa() && kind == CombiningKind::ADD) { + if (acc && llvm::isa(acc.getType()) && + kind == CombiningKind::ADD) { Value fma = rewriter.create(loc, x, y, acc); if (mask) // The fma op doesn't need explicit masking. However, fma ops used in @@ -418,7 +419,7 @@ Value promote(Value v, Type dstElementType) { Type elementType = v.getType(); - auto vecType = elementType.dyn_cast(); + auto vecType = llvm::dyn_cast(elementType); if (vecType) elementType = vecType.getElementType(); if (elementType == dstElementType) @@ -426,7 +427,7 @@ Type promotedType = dstElementType; if (vecType) promotedType = VectorType::get(vecType.getShape(), promotedType); - if (dstElementType.isa()) + if (llvm::isa(dstElementType)) return rewriter.create(loc, promotedType, v); return rewriter.create(loc, promotedType, v); } @@ -438,7 +439,8 @@ if (mask && !maybeMask.has_value()) return failure(); - Type resElementType = res.getType().cast().getElementType(); + Type resElementType = + llvm::cast(res.getType()).getElementType(); for (int64_t k = 0; k < reductionSize; ++k) { Value extractA = rewriter.create(loc, lhs, k); Value extractB = rewriter.create(loc, rhs, k); @@ -684,7 +686,7 @@ return failure(); } - VectorType dstType = op.getResultType().cast(); + VectorType dstType = llvm::cast(op.getResultType()); assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 && "Expected dst type of rank 1 or 2"); @@ -695,7 +697,7 @@ // ExtractOp does not allow dynamic indexing, we must unroll explicitly. Value res = rewriter.create(loc, dstType, rewriter.getZeroAttr(dstType)); - bool isInt = dstType.getElementType().isa(); + bool isInt = llvm::isa(dstType.getElementType()); for (unsigned r = 0; r < dstRows; ++r) { Value a = rewriter.create(op.getLoc(), lhs, r); for (unsigned c = 0; c < dstColumns; ++c) { @@ -789,7 +791,7 @@ } else { // If the parallel dimension doesn't exist we will have to broadcast it. lhsDims.push_back( - contractOp.getResultType().cast().getDimSize(i)); + llvm::cast(contractOp.getResultType()).getDimSize(i)); lhsTranspose.push_back(lhsDims.size() - 1); } std::optional rhsDim = @@ -799,7 +801,7 @@ } else { // If the parallel dimension doesn't exist we will have to broadcast it. rhsDims.push_back( - contractOp.getResultType().cast().getDimSize(i)); + llvm::cast(contractOp.getResultType()).getDimSize(i)); rhsTranspose.push_back(rhsDims.size() - 1); } } @@ -969,7 +971,7 @@ Value mask) const { VectorType lhsType = op.getLhsType(); VectorType rhsType = op.getRhsType(); - VectorType resType = op.getResultType().cast(); + VectorType resType = llvm::cast(op.getResultType()); // Find the iterator type index and result index. SmallVector iMap = op.getIndexingMapsArray(); int64_t iterIndex = -1; @@ -1044,10 +1046,10 @@ VectorType lhsType = op.getLhsType(); VectorType rhsType = op.getRhsType(); Type resType = op.getResultType(); - if (resType.isa()) + if (llvm::isa(resType)) return rewriter.notifyMatchFailure(op, "did not expect a VectorType result"); - bool isInt = resType.isa(); + bool isInt = llvm::isa(resType); // Use iterator index 0. int64_t iterIndex = 0; SmallVector iMap = op.getIndexingMapsArray(); @@ -1133,10 +1135,10 @@ auto loc = op.getLoc(); VectorType lhsType = op.getOperandVectorTypeLHS(); - VectorType rhsType = op.getOperandTypeRHS().dyn_cast(); + VectorType rhsType = llvm::dyn_cast(op.getOperandTypeRHS()); VectorType resType = op.getResultVectorType(); Type eltType = resType.getElementType(); - bool isInt = eltType.isa(); + bool isInt = llvm::isa(eltType); Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0]; vector::CombiningKind kind = op.getKind(); @@ -1231,7 +1233,7 @@ return failure(); Type dstElementType = op.getType(); - if (auto vecType = dstElementType.dyn_cast()) + if (auto vecType = llvm::dyn_cast(dstElementType)) dstElementType = vecType.getElementType(); if (elementType != dstElementType) return failure(); @@ -1259,8 +1261,8 @@ return failure(); // At this point lhs and rhs are in row-major. - VectorType lhsType = lhs.getType().cast(); - VectorType rhsType = rhs.getType().cast(); + VectorType lhsType = llvm::cast(lhs.getType()); + VectorType rhsType = llvm::cast(rhs.getType()); int64_t lhsRows = lhsType.getDimSize(0); int64_t lhsColumns = lhsType.getDimSize(1); int64_t rhsColumns = rhsType.getDimSize(1); @@ -1289,7 +1291,7 @@ llvm_unreachable("invalid contraction semantics"); Value res = - elementType.isa() + llvm::isa(elementType) ? static_cast(rew.create(loc, op.getAcc(), mul)) : static_cast( rew.create(loc, op.getAcc(), mul)); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp @@ -231,7 +231,7 @@ Value vectorMask = maskableOp.getMaskingOp().getMask(); auto maskCastedType = VectorType::get( vectorShape, - vectorMask.getType().cast().getElementType()); + llvm::cast(vectorMask.getType()).getElementType()); newVectorMask = rewriter.create(loc, maskCastedType, vectorMask); } @@ -413,7 +413,7 @@ srcVectorType.getElementType()); auto accType = VectorType::get(ArrayRef{1}, srcVectorType.getElementType()); - assert(!multiReductionOp.getDestType().isa() && + assert(!llvm::isa(multiReductionOp.getDestType()) && "multi_reduction with a single dimension expects a scalar result"); // If the unique dim is reduced and we insert a parallel in front, we need a @@ -427,7 +427,7 @@ loc, accType, multiReductionOp.getAcc()); Value castMask; if (maskableOp.isMasked()) { - auto maskType = mask.getType().cast(); + auto maskType = llvm::cast(mask.getType()); auto castMaskType = VectorType::get(ArrayRef{1, maskType.getShape().back()}, maskType.getElementType()); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp @@ -48,7 +48,7 @@ PatternRewriter &rewriter) { using vector::CombiningKind; - auto elType = x.getType().cast().getElementType(); + auto elType = llvm::cast(x.getType()).getElementType(); bool isInt = elType.isIntOrIndex(); Value combinedResult{nullptr}; diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -29,7 +29,7 @@ size_t index = 0; for (unsigned pos : permutation) newInBoundsValues[pos] = - attr.getValue()[index++].cast().getValue(); + llvm::cast(attr.getValue()[index++]).getValue(); return builder.getBoolArrayAttr(newInBoundsValues); } @@ -37,7 +37,7 @@ /// dimensions. static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec, int64_t addedRank) { - auto originalVecType = vec.getType().cast(); + auto originalVecType = llvm::cast(vec.getType()); SmallVector newShape(addedRank, 1); newShape.append(originalVecType.getShape().begin(), originalVecType.getShape().end()); @@ -257,7 +257,7 @@ // All the new dimensions added are inbound. SmallVector newInBoundsValues(missingInnerDim.size(), true); for (Attribute attr : op.getInBounds().value().getValue()) { - newInBoundsValues.push_back(attr.cast().getValue()); + newInBoundsValues.push_back(llvm::cast(attr).getValue()); } newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues); } @@ -315,7 +315,7 @@ // In the meantime, lower these to a scalar load when they pop up. if (reducedShapeRank == 0) { Value newRead; - if (op.getShapedType().isa()) { + if (llvm::isa(op.getShapedType())) { newRead = rewriter.create( op.getLoc(), op.getSource(), op.getIndices()); } else { @@ -397,7 +397,7 @@ &broadcastedDims)) return rewriter.notifyMatchFailure(read, "not minor identity + bcast"); - auto memRefType = read.getShapedType().dyn_cast(); + auto memRefType = llvm::dyn_cast(read.getShapedType()); if (!memRefType) return rewriter.notifyMatchFailure(read, "not a memref source"); @@ -418,11 +418,12 @@ // `vector.load` supports vector types as memref's elements only when the // resulting vector type is the same as the element type. auto memrefElTy = memRefType.getElementType(); - if (memrefElTy.isa() && memrefElTy != unbroadcastedVectorType) + if (llvm::isa(memrefElTy) && + memrefElTy != unbroadcastedVectorType) return rewriter.notifyMatchFailure(read, "incompatible element type"); // Otherwise, element types of the memref and the vector must match. - if (!memrefElTy.isa() && + if (!llvm::isa(memrefElTy) && memrefElTy != read.getVectorType().getElementType()) return rewriter.notifyMatchFailure(read, "non-matching element type"); @@ -543,7 +544,7 @@ diag << "permutation map is not minor identity: " << write; }); - auto memRefType = write.getShapedType().dyn_cast(); + auto memRefType = llvm::dyn_cast(write.getShapedType()); if (!memRefType) return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) { diag << "not a memref type: " << write; @@ -558,13 +559,14 @@ // `vector.store` supports vector types as memref's elements only when the // type of the vector value being written is the same as the element type. auto memrefElTy = memRefType.getElementType(); - if (memrefElTy.isa() && memrefElTy != write.getVectorType()) + if (llvm::isa(memrefElTy) && + memrefElTy != write.getVectorType()) return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) { diag << "elemental type mismatch: " << write; }); // Otherwise, element types of the memref and the vector must match. - if (!memrefElTy.isa() && + if (!llvm::isa(memrefElTy) && memrefElTy != write.getVectorType().getElementType()) return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) { diag << "elemental type mismatch: " << write; diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp @@ -156,7 +156,7 @@ /// dst[511:384] := SELECT4(v2[511:0], mask[7:6]) static Value create4x128BitSuffle(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) { - assert(v1.getType().cast().getShape()[0] == 16 && + assert(llvm::cast(v1.getType()).getShape()[0] == 16 && "expected a vector with length=16"); SmallVector shuffleMask; auto appendToMask = [&](int64_t base, uint8_t control) { @@ -291,7 +291,7 @@ vs[0xf] = create4x128BitSuffle(b, t7, tf, 0xdd); auto reshInputType = VectorType::get( - {m, n}, source.getType().cast().getElementType()); + {m, n}, llvm::cast(source.getType()).getElementType()); Value res = b.create(reshInputType, b.getZeroAttr(reshInputType)); for (int64_t i = 0; i < m; ++i) @@ -329,7 +329,7 @@ // Set up convenience transposition table. SmallVector transp; for (auto attr : op.getTransp()) - transp.push_back(attr.cast().getInt()); + transp.push_back(llvm::cast(attr).getInt()); if (isShuffleLike(vectorTransformOptions.vectorTransposeLowering) && succeeded(isTranspose2DSlice(op))) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -62,8 +62,9 @@ Value laneId, Value zero) : sequentialVal(sequentialVal), distributedVal(distributedVal), laneId(laneId), zero(zero) { - sequentialVectorType = sequentialVal.getType().dyn_cast(); - distributedVectorType = distributedVal.getType().dyn_cast(); + sequentialVectorType = llvm::dyn_cast(sequentialVal.getType()); + distributedVectorType = + llvm::dyn_cast(distributedVal.getType()); if (sequentialVectorType && distributedVectorType) distributionMap = calculateImplicitMap(sequentialVectorType, distributedVectorType); @@ -89,7 +90,7 @@ "Must store either the preregistered distributed or the " "preregistered sequential value."); // Scalar case can directly use memref.store. - if (!val.getType().isa()) + if (!llvm::isa(val.getType())) return b.create(loc, val, buffer, zero); // Vector case must use vector::TransferWriteOp which will later lower to @@ -131,7 +132,7 @@ Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) { // Scalar case can directly use memref.store. - if (!type.isa()) + if (!llvm::isa(type)) return b.create(loc, buffer, zero); // Other cases must be vector atm. @@ -149,7 +150,7 @@ } SmallVector inBounds(indices.size(), true); return b.create( - loc, type.cast(), buffer, indices, + loc, llvm::cast(type), buffer, indices, ArrayRef(inBounds.begin(), inBounds.end())); } @@ -630,14 +631,14 @@ Location loc = warpOp.getLoc(); for (OpOperand &operand : elementWise->getOpOperands()) { Type targetType; - if (auto vecType = distributedVal.getType().dyn_cast()) { + if (auto vecType = llvm::dyn_cast(distributedVal.getType())) { // If the result type is a vector, the operands must also be vectors. - auto operandType = operand.get().getType().cast(); + auto operandType = llvm::cast(operand.get().getType()); targetType = VectorType::get(vecType.getShape(), operandType.getElementType()); } else { auto operandType = operand.get().getType(); - assert(!operandType.isa() && + assert(!llvm::isa(operandType) && "unexpected yield of vector from op with scalar result type"); targetType = operandType; } @@ -687,7 +688,7 @@ if (!yieldOperand) return failure(); auto constantOp = yieldOperand->get().getDefiningOp(); - auto dense = constantOp.getValue().dyn_cast(); + auto dense = llvm::dyn_cast(constantOp.getValue()); if (!dense) return failure(); unsigned operandIndex = yieldOperand->getOperandNumber(); @@ -737,8 +738,8 @@ SmallVector indices(read.getIndices().begin(), read.getIndices().end()); - auto sequentialType = read.getResult().getType().cast(); - auto distributedType = distributedVal.getType().cast(); + auto sequentialType = llvm::cast(read.getResult().getType()); + auto distributedType = llvm::cast(distributedVal.getType()); AffineMap map = calculateImplicitMap(sequentialType, distributedType); AffineMap indexMap = map.compose(read.getPermutationMap()); OpBuilder::InsertionGuard g(rewriter); @@ -751,8 +752,8 @@ continue; unsigned indexPos = indexExpr.getPosition(); unsigned vectorPos = std::get<1>(it).cast().getPosition(); - int64_t scale = - distributedVal.getType().cast().getDimSize(vectorPos); + int64_t scale = llvm::cast(distributedVal.getType()) + .getDimSize(vectorPos); indices[indexPos] = affine::makeComposedAffineApply( rewriter, read.getLoc(), d0 + scale * d1, {indices[indexPos], warpOp.getLaneid()}); @@ -845,7 +846,7 @@ resultIndex = operand.getOperandNumber(); break; } - auto arg = operand.get().dyn_cast(); + auto arg = llvm::dyn_cast(operand.get()); if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation()) continue; Value warpOperand = warpOp.getArgs()[arg.getArgNumber()]; @@ -874,7 +875,7 @@ auto broadcastOp = operand->get().getDefiningOp(); Location loc = broadcastOp.getLoc(); auto destVecType = - warpOp->getResultTypes()[operandNumber].cast(); + llvm::cast(warpOp->getResultTypes()[operandNumber]); SmallVector newRetIndices; WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( rewriter, warpOp, {broadcastOp.getSource()}, @@ -914,7 +915,8 @@ // Rewrite vector.extract with 1d source to vector.extractelement. if (extractSrcType.getRank() == 1) { assert(extractOp.getPosition().size() == 1 && "expected 1 index"); - int64_t pos = extractOp.getPosition()[0].cast().getInt(); + int64_t pos = + llvm::cast(extractOp.getPosition()[0]).getInt(); rewriter.setInsertionPoint(extractOp); rewriter.replaceOpWithNewOp( extractOp, extractOp.getVector(), @@ -946,8 +948,8 @@ // Find the distributed dimension. There should be exactly one. auto distributedType = - warpOp.getResult(operandNumber).getType().cast(); - auto yieldedType = operand->get().getType().cast(); + llvm::cast(warpOp.getResult(operandNumber).getType()); + auto yieldedType = llvm::cast(operand->get().getType()); int64_t distributedDim = -1; for (int64_t i = 0; i < yieldedType.getRank(); ++i) { if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) { @@ -1083,7 +1085,7 @@ auto insertOp = operand->get().getDefiningOp(); VectorType vecType = insertOp.getDestVectorType(); VectorType distrType = - warpOp.getResult(operandNumber).getType().cast(); + llvm::cast(warpOp.getResult(operandNumber).getType()); bool hasPos = static_cast(insertOp.getPosition()); // Yield destination vector, source scalar and position from warp op. @@ -1171,7 +1173,7 @@ // Rewrite vector.insert with 1d dest to vector.insertelement. if (insertOp.getDestVectorType().getRank() == 1) { assert(insertOp.getPosition().size() == 1 && "expected 1 index"); - int64_t pos = insertOp.getPosition()[0].cast().getInt(); + int64_t pos = llvm::cast(insertOp.getPosition()[0]).getInt(); rewriter.setInsertionPoint(insertOp); rewriter.replaceOpWithNewOp( insertOp, insertOp.getSource(), insertOp.getDest(), @@ -1199,8 +1201,8 @@ // Find the distributed dimension. There should be exactly one. auto distrDestType = - warpOp.getResult(operandNumber).getType().cast(); - auto yieldedType = operand->get().getType().cast(); + llvm::cast(warpOp.getResult(operandNumber).getType()); + auto yieldedType = llvm::cast(operand->get().getType()); int64_t distrDestDim = -1; for (int64_t i = 0; i < yieldedType.getRank(); ++i) { if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) { @@ -1213,7 +1215,7 @@ assert(distrDestDim != -1 && "could not find distributed dimension"); // Compute the distributed source vector type. - VectorType srcVecType = insertOp.getSourceType().cast(); + VectorType srcVecType = llvm::cast(insertOp.getSourceType()); SmallVector distrSrcShape(srcVecType.getShape().begin(), srcVecType.getShape().end()); // E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32> @@ -1248,7 +1250,7 @@ int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim); SmallVector newPos = llvm::to_vector( llvm::map_range(insertOp.getPosition(), [](Attribute attr) { - return attr.cast().getInt(); + return llvm::cast(attr).getInt(); })); // tid of inserting lane: pos / elementsPerLane Value insertingLane = rewriter.create( @@ -1337,7 +1339,7 @@ if (!escapingValues.insert(operand->get())) return; Type distType = operand->get().getType(); - if (auto vecType = distType.cast()) { + if (auto vecType = llvm::cast(distType)) { AffineMap map = distributionMapFn(operand->get()); distType = getDistributedType(vecType, map, warpOp.getWarpSize()); } @@ -1359,7 +1361,7 @@ for (OpOperand &yieldOperand : yield->getOpOperands()) { if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) continue; - auto forResult = yieldOperand.get().cast(); + auto forResult = llvm::cast(yieldOperand.get()); newOperands.push_back( newWarpOp.getResult(yieldOperand.getOperandNumber())); yieldOperand.set(forOp.getIterOperands()[forResult.getResultNumber()]); @@ -1463,7 +1465,7 @@ auto reductionOp = cast(yieldOperand->get().getDefiningOp()); - auto vectorType = reductionOp.getVector().getType().cast(); + auto vectorType = llvm::cast(reductionOp.getVector().getType()); // Only rank 1 vectors supported. if (vectorType.getRank() != 1) return rewriter.notifyMatchFailure( @@ -1564,7 +1566,7 @@ // operations from there. for (auto &op : body->without_terminator()) { bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) { - return result.getType().isa(); + return llvm::isa(result.getType()); }); if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody)) opsToMove.insert(&op); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp @@ -21,7 +21,7 @@ // Helper that picks the proper sequence for inserting. static Value insertOne(PatternRewriter &rewriter, Location loc, Value from, Value into, int64_t offset) { - auto vectorType = into.getType().cast(); + auto vectorType = llvm::cast(into.getType()); if (vectorType.getRank() > 1) return rewriter.create(loc, from, into, offset); return rewriter.create( @@ -32,7 +32,7 @@ // Helper that picks the proper sequence for extracting. static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector, int64_t offset) { - auto vectorType = vector.getType().cast(); + auto vectorType = llvm::cast(vector.getType()); if (vectorType.getRank() > 1) return rewriter.create(loc, vector, offset); return rewriter.create( @@ -134,10 +134,10 @@ } int64_t offset = - op.getOffsets().getValue().front().cast().getInt(); + llvm::cast(op.getOffsets().getValue().front()).getInt(); int64_t size = srcType.getShape().front(); int64_t stride = - op.getStrides().getValue().front().cast().getInt(); + llvm::cast(op.getStrides().getValue().front()).getInt(); auto loc = op.getLoc(); Value res = op.getDest(); @@ -174,7 +174,7 @@ off += stride, ++idx) { // 1. extract the proper subvector (or element) from source Value extractedSource = extractOne(rewriter, loc, op.getSource(), idx); - if (extractedSource.getType().isa()) { + if (llvm::isa(extractedSource.getType())) { // 2. If we have a vector, extract the proper subvector from destination // Otherwise we are at the element level and no need to recurse. Value extractedDest = extractOne(rewriter, loc, op.getDest(), off); @@ -208,11 +208,11 @@ assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets"); int64_t offset = - op.getOffsets().getValue().front().cast().getInt(); + llvm::cast(op.getOffsets().getValue().front()).getInt(); int64_t size = - op.getSizes().getValue().front().cast().getInt(); + llvm::cast(op.getSizes().getValue().front()).getInt(); int64_t stride = - op.getStrides().getValue().front().cast().getInt(); + llvm::cast(op.getStrides().getValue().front()).getInt(); assert(dstType.getElementType().isSignlessIntOrIndexOrFloat()); @@ -254,11 +254,11 @@ return failure(); int64_t offset = - op.getOffsets().getValue().front().cast().getInt(); + llvm::cast(op.getOffsets().getValue().front()).getInt(); int64_t size = - op.getSizes().getValue().front().cast().getInt(); + llvm::cast(op.getSizes().getValue().front()).getInt(); int64_t stride = - op.getStrides().getValue().front().cast().getInt(); + llvm::cast(op.getStrides().getValue().front()).getInt(); Location loc = op.getLoc(); SmallVector elements; @@ -300,11 +300,11 @@ assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets"); int64_t offset = - op.getOffsets().getValue().front().cast().getInt(); + llvm::cast(op.getOffsets().getValue().front()).getInt(); int64_t size = - op.getSizes().getValue().front().cast().getInt(); + llvm::cast(op.getSizes().getValue().front()).getInt(); int64_t stride = - op.getStrides().getValue().front().cast().getInt(); + llvm::cast(op.getStrides().getValue().front()).getInt(); auto loc = op.getLoc(); auto elemType = dstType.getElementType(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -261,7 +261,7 @@ llvm::make_filter_range(sizes, [](int64_t sz) { return sz != 1; })); Type rankReducedType = memref::SubViewOp::inferRankReducedResultType( targetShape, inputType, offsets, sizes, strides); - return canonicalizeStridedLayout(rankReducedType.cast()); + return canonicalizeStridedLayout(llvm::cast(rankReducedType)); } /// Creates a rank-reducing memref.subview op that drops unit dims from its @@ -269,7 +269,7 @@ static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter, mlir::Location loc, Value input) { - MemRefType inputType = input.getType().cast(); + MemRefType inputType = llvm::cast(input.getType()); assert(inputType.hasStaticShape()); SmallVector subViewOffsets(inputType.getRank(), 0); SmallVector subViewStrides(inputType.getRank(), 1); @@ -304,9 +304,9 @@ PatternRewriter &rewriter) const override { auto loc = transferReadOp.getLoc(); Value vector = transferReadOp.getVector(); - VectorType vectorType = vector.getType().cast(); + VectorType vectorType = llvm::cast(vector.getType()); Value source = transferReadOp.getSource(); - MemRefType sourceType = source.getType().dyn_cast(); + MemRefType sourceType = llvm::dyn_cast(source.getType()); // TODO: support tensor types. if (!sourceType || !sourceType.hasStaticShape()) return failure(); @@ -347,9 +347,9 @@ PatternRewriter &rewriter) const override { auto loc = transferWriteOp.getLoc(); Value vector = transferWriteOp.getVector(); - VectorType vectorType = vector.getType().cast(); + VectorType vectorType = llvm::cast(vector.getType()); Value source = transferWriteOp.getSource(); - MemRefType sourceType = source.getType().dyn_cast(); + MemRefType sourceType = llvm::dyn_cast(source.getType()); // TODO: support tensor type. if (!sourceType || !sourceType.hasStaticShape()) return failure(); @@ -406,7 +406,7 @@ /// input starting at `firstDimToCollapse`. static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc, Value input, int64_t firstDimToCollapse) { - ShapedType inputType = input.getType().cast(); + ShapedType inputType = llvm::cast(input.getType()); if (inputType.getRank() == 1) return input; SmallVector reassociation; @@ -451,9 +451,9 @@ PatternRewriter &rewriter) const override { auto loc = transferReadOp.getLoc(); Value vector = transferReadOp.getVector(); - VectorType vectorType = vector.getType().cast(); + VectorType vectorType = llvm::cast(vector.getType()); Value source = transferReadOp.getSource(); - MemRefType sourceType = source.getType().dyn_cast(); + MemRefType sourceType = llvm::dyn_cast(source.getType()); // Contiguity check is valid on tensors only. if (!sourceType) return failure(); @@ -481,7 +481,7 @@ Value collapsedSource = collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim); MemRefType collapsedSourceType = - collapsedSource.getType().dyn_cast(); + llvm::dyn_cast(collapsedSource.getType()); int64_t collapsedRank = collapsedSourceType.getRank(); assert(collapsedRank == firstContiguousInnerDim + 1); SmallVector dimExprs{ @@ -494,7 +494,7 @@ loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap); flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true})); rewriter.replaceOpWithNewOp( - transferReadOp, vector.getType().cast(), flatRead); + transferReadOp, llvm::cast(vector.getType()), flatRead); return success(); } }; @@ -511,9 +511,9 @@ PatternRewriter &rewriter) const override { auto loc = transferWriteOp.getLoc(); Value vector = transferWriteOp.getVector(); - VectorType vectorType = vector.getType().cast(); + VectorType vectorType = llvm::cast(vector.getType()); Value source = transferWriteOp.getSource(); - MemRefType sourceType = source.getType().dyn_cast(); + MemRefType sourceType = llvm::dyn_cast(source.getType()); // Contiguity check is valid on tensors only. if (!sourceType) return failure(); @@ -541,7 +541,7 @@ Value collapsedSource = collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim); MemRefType collapsedSourceType = - collapsedSource.getType().cast(); + llvm::cast(collapsedSource.getType()); int64_t collapsedRank = collapsedSourceType.getRank(); assert(collapsedRank == firstContiguousInnerDim + 1); SmallVector dimExprs{ @@ -610,7 +610,7 @@ *getConstantIntValue(ofr)); } } - if (xferOp.getSource().getType().isa()) { + if (llvm::isa(xferOp.getSource().getType())) { rewriter.replaceOpWithNewOp(extractOp, xferOp.getSource(), newIndices); } else { @@ -637,7 +637,7 @@ LogicalResult matchAndRewrite(vector::ExtractOp extractOp, PatternRewriter &rewriter) const override { // Only match scalar extracts. - if (extractOp.getType().isa()) + if (llvm::isa(extractOp.getType())) return failure(); auto xferOp = extractOp.getVector().getDefiningOp(); if (!xferOp) @@ -660,7 +660,7 @@ SmallVector newIndices(xferOp.getIndices().begin(), xferOp.getIndices().end()); for (const auto &it : llvm::enumerate(extractOp.getPosition())) { - int64_t offset = it.value().cast().getInt(); + int64_t offset = llvm::cast(it.value()).getInt(); int64_t idx = newIndices.size() - extractOp.getPosition().size() + it.index(); OpFoldResult ofr = affine::makeComposedFoldedAffineApply( @@ -673,7 +673,7 @@ extractOp.getLoc(), *getConstantIntValue(ofr)); } } - if (xferOp.getSource().getType().isa()) { + if (llvm::isa(xferOp.getSource().getType())) { rewriter.replaceOpWithNewOp(extractOp, xferOp.getSource(), newIndices); } else { @@ -714,7 +714,7 @@ xferOp.getVector(), pos); } // Construct a scalar store. - if (xferOp.getSource().getType().isa()) { + if (llvm::isa(xferOp.getSource().getType())) { rewriter.replaceOpWithNewOp( xferOp, scalar, xferOp.getSource(), xferOp.getIndices()); } else { @@ -732,12 +732,12 @@ // Run store to load forwarding first since it can expose more dead store // opportunity. rootOp->walk([&](vector::TransferReadOp read) { - if (read.getShapedType().isa()) + if (llvm::isa(read.getShapedType())) opt.storeToLoadForwarding(read); }); opt.removeDeadOp(); rootOp->walk([&](vector::TransferWriteOp write) { - if (write.getShapedType().isa()) + if (llvm::isa(write.getShapedType())) opt.deadStoreOp(write); }); opt.removeDeadOp(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp @@ -190,7 +190,7 @@ Location loc = xferOp.getLoc(); int64_t memrefRank = xferOp.getShapedType().getRank(); // TODO: relax this precondition, will require rank-reducing subviews. - assert(memrefRank == alloc.getType().cast().getRank() && + assert(memrefRank == llvm::cast(alloc.getType()).getRank() && "Expected memref rank to match the alloc rank"); ValueRange leadingIndices = xferOp.indices().take_front(xferOp.getLeadingShapedRank()); @@ -570,9 +570,9 @@ ValueRange{}, b.getI64IntegerAttr(32)); } - MemRefType compatibleMemRefType = - getCastCompatibleMemRefType(xferOp.getShapedType().cast(), - alloc.getType().cast()); + MemRefType compatibleMemRefType = getCastCompatibleMemRefType( + llvm::cast(xferOp.getShapedType()), + llvm::cast(alloc.getType())); if (!compatibleMemRefType) return failure(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -93,9 +93,9 @@ PatternRewriter &rewriter) const override { // Check if 'shapeCastOp' has vector source/result type. auto sourceVectorType = - shapeCastOp.getSource().getType().dyn_cast_or_null(); + llvm::dyn_cast_or_null(shapeCastOp.getSource().getType()); auto resultVectorType = - shapeCastOp.getResult().getType().dyn_cast_or_null(); + llvm::dyn_cast_or_null(shapeCastOp.getResult().getType()); if (!sourceVectorType || !resultVectorType) return failure(); @@ -105,7 +105,7 @@ if (!sourceShapeCastOp) return failure(); auto operandSourceVectorType = - sourceShapeCastOp.getSource().getType().cast(); + llvm::cast(sourceShapeCastOp.getSource().getType()); auto operandResultVectorType = sourceShapeCastOp.getType(); // Check if shape cast operations invert each other. @@ -342,7 +342,7 @@ if (!broadcast) continue; // contractionOp can only take vector as operands. - auto srcType = broadcast.getSourceType().dyn_cast(); + auto srcType = llvm::dyn_cast(broadcast.getSourceType()); if (!srcType || srcType.getRank() == broadcast.getResultVectorType().getRank()) continue; @@ -455,7 +455,7 @@ return failure(); Type castResTy = getElementTypeOrSelf(op->getResult(0)); - if (auto vecTy = bcastOp.getSourceType().dyn_cast()) + if (auto vecTy = llvm::dyn_cast(bcastOp.getSourceType())) castResTy = VectorType::get(vecTy.getShape(), castResTy); auto *castOp = rewriter.create(op->getLoc(), op->getName().getIdentifier(), @@ -530,7 +530,7 @@ // This is a constant. Create a reverse transpose op for it. auto vectorType = VectorType::get( srcType.getShape(), - operand.getType().cast().getElementType()); + llvm::cast(operand.getType()).getElementType()); srcValues.push_back(rewriter.create( operand.getLoc(), vectorType, operand, rewriter.getI64ArrayAttr(invOrder))); @@ -539,7 +539,7 @@ auto vectorType = VectorType::get( srcType.getShape(), - op->getResultTypes()[0].cast().getElementType()); + llvm::cast(op->getResultTypes()[0]).getElementType()); Operation *elementwiseOp = rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues, vectorType, op->getAttrs()); @@ -692,8 +692,8 @@ newSizes = rewriter.getI64ArrayAttr(sizes); } - SmallVector dims = - llvm::to_vector<4>(extractOp.getType().cast().getShape()); + SmallVector dims = llvm::to_vector<4>( + llvm::cast(extractOp.getType()).getShape()); dims.back() = dims.back() / expandRatio; VectorType newExtractType = VectorType::get(dims, castSrcType.getElementType()); @@ -996,7 +996,7 @@ LogicalResult matchAndRewrite(vector::CreateMaskOp op, PatternRewriter &rewriter) const override { auto dstType = op.getType(); - if (dstType.cast().isScalable()) + if (llvm::cast(dstType).isScalable()) return failure(); int64_t rank = dstType.getRank(); if (rank > 1) @@ -1026,7 +1026,7 @@ if (readOp.getMask()) return failure(); - auto srcType = readOp.getSource().getType().dyn_cast(); + auto srcType = llvm::dyn_cast(readOp.getSource().getType()); if (!srcType || !srcType.hasStaticShape()) return failure(); @@ -1060,13 +1060,13 @@ MemRefType resultMemrefType; MemRefLayoutAttrInterface layout = srcType.getLayout(); - if (layout.isa() && layout.isIdentity()) { + if (llvm::isa(layout) && layout.isIdentity()) { resultMemrefType = MemRefType::get( srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(), nullptr, srcType.getMemorySpace()); } else { MemRefLayoutAttrInterface updatedLayout; - if (auto strided = layout.dyn_cast()) { + if (auto strided = llvm::dyn_cast(layout)) { auto strides = llvm::to_vector(strided.getStrides().drop_back(dimsToDrop)); updatedLayout = StridedLayoutAttr::get(strided.getContext(), @@ -1099,7 +1099,8 @@ loc, resultMemrefType, readOp.getSource(), offsets, srcType.getShape(), strides); auto permMap = getTransferMinorIdentityMap( - rankedReducedView.getType().cast(), resultTargetVecType); + llvm::cast(rankedReducedView.getType()), + resultTargetVecType); Value result = rewriter.create( loc, resultTargetVecType, rankedReducedView, readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap), diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -36,9 +36,9 @@ /// the type of `source`. Value mlir::vector::createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim) { - if (source.getType().isa()) + if (llvm::isa(source.getType())) return b.createOrFold(loc, source, dim); - if (source.getType().isa()) + if (llvm::isa(source.getType())) return b.createOrFold(loc, source, dim); llvm_unreachable("Expected MemRefType or TensorType"); } @@ -89,7 +89,7 @@ SmallVector transp; for (auto attr : op.getTransp()) - transp.push_back(attr.cast().getInt()); + transp.push_back(llvm::cast(attr).getInt()); // Check whether the two source vector dimensions that are greater than one // must be transposed with each other so that we can apply one of the 2-D @@ -223,7 +223,7 @@ } return false; } else if (op.getNumResults() == 1) { - if (auto v = op.getResult(0).getType().dyn_cast()) { + if (auto v = llvm::dyn_cast(op.getResult(0).getType())) { superVectorType = v; } else { // Not a vector type. diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp --- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp @@ -22,11 +22,11 @@ /// Extracts the "main" vector element type from the given X86Vector operation. template static Type getSrcVectorElementType(OpTy op) { - return op.getSrc().getType().template cast().getElementType(); + return llvm::cast(op.getSrc().getType()).getElementType(); } template <> Type getSrcVectorElementType(Vp2IntersectOp op) { - return op.getA().getType().template cast().getElementType(); + return llvm::cast(op.getA().getType()).getElementType(); } namespace { diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp --- a/mlir/lib/ExecutionEngine/JitRunner.cpp +++ b/mlir/lib/ExecutionEngine/JitRunner.cpp @@ -288,30 +288,29 @@ Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction); template <> Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction) { - auto resultType = mainFunction.getFunctionType() - .cast() - .getReturnType() - .dyn_cast(); + auto resultType = + llvm::dyn_cast(mainFunction.getFunctionType() + .cast() + .getReturnType()); if (!resultType || resultType.getWidth() != 32) return makeStringError("only single i32 function result supported"); return Error::success(); } template <> Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction) { - auto resultType = mainFunction.getFunctionType() - .cast() - .getReturnType() - .dyn_cast(); + auto resultType = + llvm::dyn_cast(mainFunction.getFunctionType() + .cast() + .getReturnType()); if (!resultType || resultType.getWidth() != 64) return makeStringError("only single i64 function result supported"); return Error::success(); } template <> Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction) { - if (!mainFunction.getFunctionType() - .cast() - .getReturnType() - .isa()) + if (!llvm::isa(mainFunction.getFunctionType() + .cast() + .getReturnType())) return makeStringError("only single f32 function result supported"); return Error::success(); } @@ -324,8 +323,7 @@ if (!mainFunction || mainFunction.isExternal()) return makeStringError("entry point not found"); - if (mainFunction.getFunctionType() - .cast() + if (llvm::cast(mainFunction.getFunctionType()) .getNumParams() != 0) return makeStringError("function inputs not supported"); diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -66,14 +66,14 @@ case AffineExprKind::Constant: return expr.cast().getValue(); case AffineExprKind::DimId: - if (auto attr = operandConsts[expr.cast().getPosition()] - .dyn_cast_or_null()) + if (auto attr = llvm::dyn_cast_or_null( + operandConsts[expr.cast().getPosition()])) return attr.getInt(); return std::nullopt; case AffineExprKind::SymbolId: - if (auto attr = operandConsts[numDims + - expr.cast().getPosition()] - .dyn_cast_or_null()) + if (auto attr = llvm::dyn_cast_or_null( + operandConsts[numDims + + expr.cast().getPosition()])) return attr.getInt(); return std::nullopt; } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -91,7 +91,7 @@ // it is a function (avoiding a grammar ambiguity). bool wrapped = op->getNumResults() != 1; if (!wrapped && op->getResult(0).getType() && - op->getResult(0).getType().isa()) + llvm::isa(op->getResult(0).getType())) wrapped = true; if (wrapped) @@ -254,7 +254,7 @@ bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const { return elementsAttrElementLimit && *elementsAttrElementLimit < int64_t(attr.getNumElements()) && - !attr.isa(); + !llvm::isa(attr); } /// Return the size limit for printing large ElementsAttr. @@ -803,8 +803,8 @@ attr.getDialect().printAttribute(attr, *this); // Process the builtin attributes. - } else if (attr.isa()) { + } else if (llvm::isa(attr)) { return; } else if (auto dictAttr = dyn_cast(attr)) { for (const NamedAttribute &nestedAttr : dictAttr.getValue()) { @@ -833,9 +833,9 @@ // Don't print the type if we must elide it, or if it is a None type. if (!elideType) { - if (auto typedAttr = attr.dyn_cast()) { + if (auto typedAttr = llvm::dyn_cast(attr)) { Type attrType = typedAttr.getType(); - if (!attrType.isa()) + if (!llvm::isa(attrType)) printType(attrType); } } @@ -845,10 +845,10 @@ return type.getDialect().printType(type, *this); // Only visit the layout of memref if it isn't the identity. - if (auto memrefTy = type.dyn_cast()) { + if (auto memrefTy = llvm::dyn_cast(type)) { printType(memrefTy.getElementType()); MemRefLayoutAttrInterface layout = memrefTy.getLayout(); - if (!layout.isa() || !layout.isIdentity()) + if (!llvm::isa(layout) || !layout.isIdentity()) printAttribute(memrefTy.getLayout()); if (memrefTy.getMemorySpace()) printAttribute(memrefTy.getMemorySpace()); @@ -1418,7 +1418,7 @@ void SSANameState::numberValuesInRegion(Region ®ion) { auto setBlockArgNameFn = [&](Value arg, StringRef name) { assert(!valueIDs.count(arg) && "arg numbered multiple times"); - assert(arg.cast().getOwner()->getParent() == ®ion && + assert(llvm::cast(arg).getOwner()->getParent() == ®ion && "arg not defined in current region"); setValueName(arg, name); }; @@ -1479,7 +1479,7 @@ setValueName(result, name); // Record the result number for groups not anchored at 0. - if (int resultNo = result.cast().getResultNumber()) + if (int resultNo = llvm::cast(result).getResultNumber()) resultGroups.push_back(resultNo); }; // Operations can customize the printing of block names in OpAsmOpInterface. @@ -1878,7 +1878,7 @@ // Print the child if it isn't unknown. auto childLoc = loc.getChildLoc(); - if (!childLoc.isa()) { + if (!llvm::isa(childLoc)) { os << '('; printLocationInternal(childLoc, pretty); os << ')'; @@ -1891,8 +1891,8 @@ os << "callsite("; printLocationInternal(callee, pretty); if (pretty) { - if (callee.isa()) { - if (caller.isa()) { + if (llvm::isa(callee)) { + if (llvm::isa(caller)) { os << " at "; } else { os << newLine << " at "; @@ -2100,19 +2100,19 @@ AttrTypeElision typeElision) { if (!isa(attr.getDialect())) { printDialectAttribute(attr); - } else if (auto opaqueAttr = attr.dyn_cast()) { + } else if (auto opaqueAttr = llvm::dyn_cast(attr)) { printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(), opaqueAttr.getAttrData()); - } else if (attr.isa()) { + } else if (llvm::isa(attr)) { os << "unit"; return; - } else if (auto dictAttr = attr.dyn_cast()) { + } else if (auto dictAttr = llvm::dyn_cast(attr)) { os << '{'; interleaveComma(dictAttr.getValue(), [&](NamedAttribute attr) { printNamedAttribute(attr); }); os << '}'; - } else if (auto intAttr = attr.dyn_cast()) { + } else if (auto intAttr = llvm::dyn_cast(attr)) { Type intType = intAttr.getType(); if (intType.isSignlessInteger(1)) { os << (intAttr.getValue().getBoolValue() ? "true" : "false"); @@ -2132,24 +2132,24 @@ if (typeElision == AttrTypeElision::May && intType.isSignlessInteger(64)) return; - } else if (auto floatAttr = attr.dyn_cast()) { + } else if (auto floatAttr = llvm::dyn_cast(attr)) { printFloatValue(floatAttr.getValue(), os); // FloatAttr elides the type if F64. if (typeElision == AttrTypeElision::May && floatAttr.getType().isF64()) return; - } else if (auto strAttr = attr.dyn_cast()) { + } else if (auto strAttr = llvm::dyn_cast(attr)) { printEscapedString(strAttr.getValue()); - } else if (auto arrayAttr = attr.dyn_cast()) { + } else if (auto arrayAttr = llvm::dyn_cast(attr)) { os << '['; interleaveComma(arrayAttr.getValue(), [&](Attribute attr) { printAttribute(attr, AttrTypeElision::May); }); os << ']'; - } else if (auto affineMapAttr = attr.dyn_cast()) { + } else if (auto affineMapAttr = llvm::dyn_cast(attr)) { os << "affine_map<"; affineMapAttr.getValue().print(os); os << '>'; @@ -2157,7 +2157,7 @@ // AffineMap always elides the type. return; - } else if (auto integerSetAttr = attr.dyn_cast()) { + } else if (auto integerSetAttr = llvm::dyn_cast(attr)) { os << "affine_set<"; integerSetAttr.getValue().print(os); os << '>'; @@ -2165,17 +2165,18 @@ // IntegerSet always elides the type. return; - } else if (auto typeAttr = attr.dyn_cast()) { + } else if (auto typeAttr = llvm::dyn_cast(attr)) { printType(typeAttr.getValue()); - } else if (auto refAttr = attr.dyn_cast()) { + } else if (auto refAttr = llvm::dyn_cast(attr)) { printSymbolReference(refAttr.getRootReference().getValue(), os); for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) { os << "::"; printSymbolReference(nestedRef.getValue(), os); } - } else if (auto intOrFpEltAttr = attr.dyn_cast()) { + } else if (auto intOrFpEltAttr = + llvm::dyn_cast(attr)) { if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) { printElidedElementsAttr(os); } else { @@ -2184,7 +2185,7 @@ os << '>'; } - } else if (auto strEltAttr = attr.dyn_cast()) { + } else if (auto strEltAttr = llvm::dyn_cast(attr)) { if (printerFlags.shouldElideElementsAttr(strEltAttr)) { printElidedElementsAttr(os); } else { @@ -2193,7 +2194,7 @@ os << '>'; } - } else if (auto sparseEltAttr = attr.dyn_cast()) { + } else if (auto sparseEltAttr = llvm::dyn_cast(attr)) { if (printerFlags.shouldElideElementsAttr(sparseEltAttr.getIndices()) || printerFlags.shouldElideElementsAttr(sparseEltAttr.getValues())) { printElidedElementsAttr(os); @@ -2207,9 +2208,9 @@ } os << '>'; } - } else if (auto stridedLayoutAttr = attr.dyn_cast()) { + } else if (auto stridedLayoutAttr = llvm::dyn_cast(attr)) { stridedLayoutAttr.print(os); - } else if (auto denseArrayAttr = attr.dyn_cast()) { + } else if (auto denseArrayAttr = llvm::dyn_cast(attr)) { os << "array<"; printType(denseArrayAttr.getElementType()); if (!denseArrayAttr.empty()) { @@ -2218,20 +2219,21 @@ } os << ">"; return; - } else if (auto resourceAttr = attr.dyn_cast()) { + } else if (auto resourceAttr = + llvm::dyn_cast(attr)) { os << "dense_resource<"; printResourceHandle(resourceAttr.getRawHandle()); os << ">"; - } else if (auto locAttr = attr.dyn_cast()) { + } else if (auto locAttr = llvm::dyn_cast(attr)) { printLocation(locAttr); } else { llvm::report_fatal_error("Unknown builtin attribute"); } // Don't print the type if we must elide it, or if it is a None type. if (typeElision != AttrTypeElision::Must) { - if (auto typedAttr = attr.dyn_cast()) { + if (auto typedAttr = llvm::dyn_cast(attr)) { Type attrType = typedAttr.getType(); - if (!attrType.isa()) { + if (!llvm::isa(attrType)) { os << " : "; printType(attrType); } @@ -2300,10 +2302,10 @@ void AsmPrinter::Impl::printDenseElementsAttr(DenseElementsAttr attr, bool allowHex) { - if (auto stringAttr = attr.dyn_cast()) + if (auto stringAttr = llvm::dyn_cast(attr)) return printDenseStringElementsAttr(stringAttr); - printDenseIntOrFPElementsAttr(attr.cast(), + printDenseIntOrFPElementsAttr(llvm::cast(attr), allowHex); } @@ -2333,12 +2335,12 @@ return; } - if (ComplexType complexTy = elementType.dyn_cast()) { + if (ComplexType complexTy = llvm::dyn_cast(elementType)) { Type complexElementType = complexTy.getElementType(); // Note: The if and else below had a common lambda function which invoked // printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2 // and hence was replaced. - if (complexElementType.isa()) { + if (llvm::isa(complexElementType)) { auto valueIt = attr.value_begin>(); printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { auto complexValue = *(valueIt + index); @@ -2365,7 +2367,7 @@ printDenseIntElement(*(valueIt + index), os, elementType); }); } else { - assert(elementType.isa() && "unexpected element type"); + assert(llvm::isa(elementType) && "unexpected element type"); auto valueIt = attr.value_begin(); printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { printFloatValue(*(valueIt + index), os); @@ -2397,7 +2399,7 @@ if (type.isIntOrIndex()) { printDenseIntElement(value, getStream(), type); } else { - APFloat fltVal(type.cast().getFloatSemantics(), value); + APFloat fltVal(llvm::cast(type).getFloatSemantics(), value); printFloatValue(fltVal, getStream()); } }; @@ -2447,7 +2449,7 @@ interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); }); os << ") -> "; ArrayRef results = funcTy.getResults(); - if (results.size() == 1 && !results[0].isa()) { + if (results.size() == 1 && !llvm::isa(results[0])) { printType(results[0]); } else { os << '('; @@ -2506,7 +2508,7 @@ } printType(memrefTy.getElementType()); MemRefLayoutAttrInterface layout = memrefTy.getLayout(); - if (!layout.isa() || !layout.isIdentity()) { + if (!llvm::isa(layout) || !layout.isIdentity()) { os << ", "; printAttribute(memrefTy.getLayout(), AttrTypeElision::May); } @@ -2580,7 +2582,7 @@ ::printKeywordOrString(attr.getName().strref(), os); // Pretty printing elides the attribute value for unit attributes. - if (attr.getValue().isa()) + if (llvm::isa(attr.getValue())) return; os << " = "; diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -33,7 +33,7 @@ /// Return the bit width which DenseElementsAttr should use for this type. inline size_t getDenseElementBitWidth(Type eltType) { // Align the width for complex to 8 to make storage and interpretation easier. - if (ComplexType comp = eltType.dyn_cast()) + if (ComplexType comp = llvm::dyn_cast(eltType)) return llvm::alignTo<8>(getDenseElementBitWidth(comp.getElementType())) * 2; if (eltType.isIndex()) return IndexType::kInternalStorageBitWidth; diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -46,7 +46,9 @@ assert(name.size() != 0 && "expected valid attribute name"); } -StringAttr NamedAttribute::getName() const { return name.cast(); } +StringAttr NamedAttribute::getName() const { + return llvm::cast(name); +} Dialect *NamedAttribute::getNameDialect() const { return getName().getReferencedDialect(); diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -316,14 +316,15 @@ } TypedAttr Builder::getZeroAttr(Type type) { - if (type.isa()) + if (llvm::isa(type)) return getFloatAttr(type, 0.0); - if (type.isa()) + if (llvm::isa(type)) return getIndexAttr(0); - if (auto integerType = type.dyn_cast()) - return getIntegerAttr(type, APInt(type.cast().getWidth(), 0)); - if (type.isa()) { - auto vtType = type.cast(); + if (auto integerType = llvm::dyn_cast(type)) + return getIntegerAttr(type, + APInt(llvm::cast(type).getWidth(), 0)); + if (llvm::isa(type)) { + auto vtType = llvm::cast(type); auto element = getZeroAttr(vtType.getElementType()); if (!element) return {}; diff --git a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp --- a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp +++ b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp @@ -53,7 +53,7 @@ } uint64_t ElementsAttr::getFlattenedIndex(Type type, ArrayRef index) { - ShapedType shapeType = type.cast(); + ShapedType shapeType = llvm::cast(type); assert(isValidIndex(shapeType, index) && "expected valid multi-dimensional index"); diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -300,11 +300,12 @@ LogicalResult FloatAttr::verify(function_ref emitError, Type type, APFloat value) { // Verify that the type is correct. - if (!type.isa()) + if (!llvm::isa(type)) return emitError() << "expected floating point type"; // Verify that the type semantics match that of the value. - if (&type.cast().getFloatSemantics() != &value.getSemantics()) { + if (&llvm::cast(type).getFloatSemantics() != + &value.getSemantics()) { return emitError() << "FloatAttr type doesn't match the type implied by its value"; } @@ -321,11 +322,11 @@ } FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) { - return get(ctx, value, {}).cast(); + return llvm::cast(get(ctx, value, {})); } FlatSymbolRefAttr SymbolRefAttr::get(StringAttr value) { - return get(value, {}).cast(); + return llvm::cast(get(value, {})); } FlatSymbolRefAttr SymbolRefAttr::get(Operation *symbol) { @@ -370,14 +371,14 @@ LogicalResult IntegerAttr::verify(function_ref emitError, Type type, APInt value) { - if (IntegerType integerType = type.dyn_cast()) { + if (IntegerType integerType = llvm::dyn_cast(type)) { if (integerType.getWidth() != value.getBitWidth()) return emitError() << "integer type bit width (" << integerType.getWidth() << ") doesn't match value bit width (" << value.getBitWidth() << ")"; return success(); } - if (type.isa()) { + if (llvm::isa(type)) { if (value.getBitWidth() != IndexType::kInternalStorageBitWidth) return emitError() << "value bit width (" << value.getBitWidth() @@ -390,7 +391,7 @@ BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) { auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value)); - return attr.cast(); + return llvm::cast(attr); } //===----------------------------------------------------------------------===// @@ -403,7 +404,7 @@ } bool BoolAttr::classof(Attribute attr) { - IntegerAttr intAttr = attr.dyn_cast(); + IntegerAttr intAttr = llvm::dyn_cast(attr); return intAttr && intAttr.getType().isSignlessInteger(1); } @@ -600,21 +601,21 @@ attr.getAsOpaquePointer(), index) {} Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { - auto owner = getFromOpaquePointer(base).cast(); + auto owner = llvm::cast(getFromOpaquePointer(base)); Type eltTy = owner.getElementType(); - if (auto intEltTy = eltTy.dyn_cast()) + if (auto intEltTy = llvm::dyn_cast(eltTy)) return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); - if (eltTy.isa()) + if (llvm::isa(eltTy)) return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); - if (auto floatEltTy = eltTy.dyn_cast()) { + if (auto floatEltTy = llvm::dyn_cast(eltTy)) { IntElementIterator intIt(owner, index); FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); return FloatAttr::get(eltTy, *floatIt); } - if (auto complexTy = eltTy.dyn_cast()) { + if (auto complexTy = llvm::dyn_cast(eltTy)) { auto complexEltTy = complexTy.getElementType(); ComplexIntElementIterator complexIntIt(owner, index); - if (complexEltTy.isa()) { + if (llvm::isa(complexEltTy)) { auto value = *complexIntIt; auto real = IntegerAttr::get(complexEltTy, value.real()); auto imag = IntegerAttr::get(complexEltTy, value.imag()); @@ -623,14 +624,14 @@ } ComplexFloatElementIterator complexFloatIt( - complexEltTy.cast().getFloatSemantics(), complexIntIt); + llvm::cast(complexEltTy).getFloatSemantics(), complexIntIt); auto value = *complexFloatIt; auto real = FloatAttr::get(complexEltTy, value.real()); auto imag = FloatAttr::get(complexEltTy, value.imag()); return ArrayAttr::get(complexTy.getContext(), ArrayRef{real, imag}); } - if (owner.isa()) { + if (llvm::isa(owner)) { ArrayRef vals = owner.getRawStringData(); return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy); } @@ -673,7 +674,7 @@ std::complex, std::complex, std::complex>( attr.getRawData().data(), attr.isSplat(), dataIndex) { - auto complexType = attr.getElementType().cast(); + auto complexType = llvm::cast(attr.getElementType()); bitWidth = getDenseElementBitWidth(complexType.getElementType()); } @@ -713,7 +714,7 @@ IntegerType::SignednessSemantics signedness = IntegerType::Signless> struct DenseArrayAttrIntUtil { static bool checkElementType(Type eltType) { - auto type = eltType.dyn_cast(); + auto type = llvm::dyn_cast(eltType); if (!type || type.getWidth() != width) return false; return type.getSignedness() == signedness; @@ -860,7 +861,7 @@ template bool DenseArrayAttrImpl::classof(Attribute attr) { - if (auto denseArray = attr.dyn_cast()) + if (auto denseArray = llvm::dyn_cast(attr)) return DenseArrayAttrUtil::checkElementType(denseArray.getElementType()); return false; } @@ -884,7 +885,7 @@ /// Method for support type inquiry through isa, cast and dyn_cast. bool DenseElementsAttr::classof(Attribute attr) { - return attr.isa(); + return llvm::isa(attr); } DenseElementsAttr DenseElementsAttr::get(ShapedType type, @@ -894,20 +895,19 @@ Type eltType = type.getElementType(); // Take care complex type case first. - if (auto complexType = eltType.dyn_cast()) { + if (auto complexType = llvm::dyn_cast(eltType)) { if (complexType.getElementType().isIntOrIndex()) { SmallVector> complexValues; complexValues.reserve(values.size()); for (Attribute attr : values) { - assert(attr.isa() && - "expected ArrayAttr for complex"); - auto arrayAttr = attr.cast(); + assert(llvm::isa(attr) && "expected ArrayAttr for complex"); + auto arrayAttr = llvm::cast(attr); assert(arrayAttr.size() == 2 && "expected 2 element for complex"); auto attr0 = arrayAttr[0]; auto attr1 = arrayAttr[1]; complexValues.push_back( - std::complex(attr0.cast().getValue(), - attr1.cast().getValue())); + std::complex(llvm::cast(attr0).getValue(), + llvm::cast(attr1).getValue())); } return DenseElementsAttr::get(type, complexValues); } @@ -915,14 +915,14 @@ SmallVector> complexValues; complexValues.reserve(values.size()); for (Attribute attr : values) { - assert(attr.isa() && "expected ArrayAttr for complex"); - auto arrayAttr = attr.cast(); + assert(llvm::isa(attr) && "expected ArrayAttr for complex"); + auto arrayAttr = llvm::cast(attr); assert(arrayAttr.size() == 2 && "expected 2 element for complex"); auto attr0 = arrayAttr[0]; auto attr1 = arrayAttr[1]; complexValues.push_back( - std::complex(attr0.cast().getValue(), - attr1.cast().getValue())); + std::complex(llvm::cast(attr0).getValue(), + llvm::cast(attr1).getValue())); } return DenseElementsAttr::get(type, complexValues); } @@ -933,9 +933,9 @@ SmallVector stringValues; stringValues.reserve(values.size()); for (Attribute attr : values) { - assert(attr.isa() && + assert(llvm::isa(attr) && "expected string value for non integer/index/float element"); - stringValues.push_back(attr.cast().getValue()); + stringValues.push_back(llvm::cast(attr).getValue()); } return get(type, stringValues); } @@ -949,12 +949,12 @@ llvm::divideCeil(storageBitWidth * values.size(), CHAR_BIT)); APInt intVal; for (unsigned i = 0, e = values.size(); i < e; ++i) { - if (auto floatAttr = values[i].dyn_cast()) { + if (auto floatAttr = llvm::dyn_cast(values[i])) { assert(floatAttr.getType() == eltType && "expected float attribute type to equal element type"); intVal = floatAttr.getValue().bitcastToAPInt(); } else { - auto intAttr = values[i].cast(); + auto intAttr = llvm::cast(values[i]); assert(intAttr.getType() == eltType && "expected integer attribute type to equal element type"); intVal = intAttr.getValue(); @@ -1015,8 +1015,8 @@ } DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef> values) { - ComplexType complex = type.getElementType().cast(); - assert(complex.getElementType().isa()); + ComplexType complex = llvm::cast(type.getElementType()); + assert(llvm::isa(complex.getElementType())); assert(hasSameElementsOrSplat(type, values)); size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; ArrayRef intVals(reinterpret_cast(values.data()), @@ -1029,7 +1029,7 @@ // element type of 'type'. DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef values) { - assert(type.getElementType().isa()); + assert(llvm::isa(type.getElementType())); assert(hasSameElementsOrSplat(type, values)); size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values); @@ -1037,8 +1037,8 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef> values) { - ComplexType complex = type.getElementType().cast(); - assert(complex.getElementType().isa()); + ComplexType complex = llvm::cast(type.getElementType()); + assert(llvm::isa(complex.getElementType())); assert(hasSameElementsOrSplat(type, values)); ArrayRef apVals(reinterpret_cast(values.data()), values.size() * 2); @@ -1104,11 +1104,11 @@ // Check that the element type is either float or integer or index. if (!isInt) - return type.isa(); + return llvm::isa(type); if (type.isIndex()) return true; - auto intType = type.dyn_cast(); + auto intType = llvm::dyn_cast(type); if (!intType) return false; @@ -1142,8 +1142,8 @@ bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt, bool isSigned) const { return ::isValidIntOrFloat( - getElementType().cast().getElementType(), dataEltSize / 2, - isInt, isSigned); + llvm::cast(getElementType()).getElementType(), + dataEltSize / 2, isInt, isSigned); } /// Returns true if this attribute corresponds to a splat, i.e. if all element @@ -1154,7 +1154,7 @@ /// Return if the given complex type has an integer element type. static bool isComplexOfIntType(Type type) { - return type.cast().getElementType().isa(); + return llvm::isa(type.cast().getElementType()); } auto DenseElementsAttr::tryGetComplexIntValues() const @@ -1168,7 +1168,7 @@ auto DenseElementsAttr::tryGetFloatValues() const -> FailureOr> { - auto eltTy = getElementType().dyn_cast(); + auto eltTy = llvm::dyn_cast(getElementType()); if (!eltTy) return failure(); const auto &elementSemantics = eltTy.getFloatSemantics(); @@ -1179,10 +1179,10 @@ auto DenseElementsAttr::tryGetComplexFloatValues() const -> FailureOr> { - auto complexTy = getElementType().dyn_cast(); + auto complexTy = llvm::dyn_cast(getElementType()); if (!complexTy) return failure(); - auto eltTy = complexTy.getElementType().dyn_cast(); + auto eltTy = llvm::dyn_cast(complexTy.getElementType()); if (!eltTy) return failure(); const auto &semantics = eltTy.getFloatSemantics(); @@ -1331,7 +1331,7 @@ bool isInt, bool isSigned) { assert(::isValidIntOrFloat( - type.getElementType().cast().getElementType(), + llvm::cast(type.getElementType()).getElementType(), dataEltSize / 2, isInt, isSigned)); int64_t numElements = data.size() / dataEltSize; @@ -1404,7 +1404,7 @@ ShapedType type) { size_t numElements = type.getNumElements(); Type elementType = type.getElementType(); - if (ComplexType complexTy = elementType.dyn_cast()) { + if (ComplexType complexTy = llvm::dyn_cast(elementType)) { elementType = complexTy.getElementType(); numElements = numElements * 2; } @@ -1470,8 +1470,8 @@ /// Method for supporting type inquiry through isa, cast and dyn_cast. bool DenseFPElementsAttr::classof(Attribute attr) { - if (auto denseAttr = attr.dyn_cast()) - return denseAttr.getType().getElementType().isa(); + if (auto denseAttr = llvm::dyn_cast(attr)) + return llvm::isa(denseAttr.getType().getElementType()); return false; } @@ -1489,7 +1489,7 @@ /// Method for supporting type inquiry through isa, cast and dyn_cast. bool DenseIntElementsAttr::classof(Attribute attr) { - if (auto denseAttr = attr.dyn_cast()) + if (auto denseAttr = llvm::dyn_cast(attr)) return denseAttr.getType().getElementType().isIntOrIndex(); return false; } @@ -1525,7 +1525,7 @@ template struct DenseResourceElementsAttrIntUtil { static bool checkElementType(Type eltType) { - IntegerType type = eltType.dyn_cast(); + IntegerType type = llvm::dyn_cast(eltType); if (!type || type.getWidth() != width) return false; return isSigned ? !type.isUnsigned() : !type.isSigned(); @@ -1582,8 +1582,8 @@ "size mismatch between expected element width and blob size"); assert(DenseResourceAttrUtil::checkElementType(type.getElementType()) && "invalid shape element type for provided type `T`"); - return DenseResourceElementsAttr::get(type, blobName, std::move(blob)) - .template cast>(); + return llvm::cast>( + DenseResourceElementsAttr::get(type, blobName, std::move(blob))); } template @@ -1596,7 +1596,7 @@ template bool DenseResourceElementsAttrBase::classof(Attribute attr) { - auto resourceAttr = attr.dyn_cast(); + auto resourceAttr = llvm::dyn_cast(attr); return resourceAttr && DenseResourceAttrUtil::checkElementType( resourceAttr.getElementType()); } @@ -1624,13 +1624,13 @@ /// Get a zero APFloat for the given sparse attribute. APFloat SparseElementsAttr::getZeroAPFloat() const { - auto eltType = getElementType().cast(); + auto eltType = llvm::cast(getElementType()); return APFloat(eltType.getFloatSemantics()); } /// Get a zero APInt for the given sparse attribute. APInt SparseElementsAttr::getZeroAPInt() const { - auto eltType = getElementType().cast(); + auto eltType = llvm::cast(getElementType()); return APInt::getZero(eltType.getWidth()); } @@ -1639,14 +1639,14 @@ auto eltType = getElementType(); // Handle floating point elements. - if (eltType.isa()) + if (llvm::isa(eltType)) return FloatAttr::get(eltType, 0); // Handle complex elements. - if (auto complexTy = eltType.dyn_cast()) { + if (auto complexTy = llvm::dyn_cast(eltType)) { auto eltType = complexTy.getElementType(); Attribute zero; - if (eltType.isa()) + if (llvm::isa(eltType)) zero = FloatAttr::get(eltType, 0); else // must be integer zero = IntegerAttr::get(eltType, 0); @@ -1655,7 +1655,7 @@ } // Handle string type. - if (getValues().isa()) + if (llvm::isa(getValues())) return StringAttr::get("", eltType); // Otherwise, this is an integer. diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp --- a/mlir/lib/IR/BuiltinDialect.cpp +++ b/mlir/lib/IR/BuiltinDialect.cpp @@ -48,15 +48,15 @@ : OpAsmDialectInterface(dialect), blobManager(mgr) {} AliasResult getAlias(Attribute attr, raw_ostream &os) const override { - if (attr.isa()) { + if (llvm::isa(attr)) { os << "map"; return AliasResult::OverridableAlias; } - if (attr.isa()) { + if (llvm::isa(attr)) { os << "set"; return AliasResult::OverridableAlias; } - if (attr.isa()) { + if (llvm::isa(attr)) { os << "loc"; return AliasResult::OverridableAlias; } @@ -64,7 +64,7 @@ } AliasResult getAlias(Type type, raw_ostream &os) const final { - if (auto tupleType = type.dyn_cast()) { + if (auto tupleType = llvm::dyn_cast(type)) { if (tupleType.size() > 16) { os << "tuple"; return AliasResult::OverridableAlias; @@ -145,7 +145,7 @@ // interface. This needs a linear search, but is called only once per data // layout object construction that is used for repeated queries. for (NamedAttribute attr : getOperation()->getAttrs()) - if (auto spec = attr.getValue().dyn_cast()) + if (auto spec = llvm::dyn_cast(attr.getValue())) return spec; return {}; } @@ -168,7 +168,7 @@ StringRef layoutSpecAttrName; DataLayoutSpecInterface layoutSpec; for (const NamedAttribute &na : (*this)->getAttrs()) { - if (auto spec = na.getValue().dyn_cast()) { + if (auto spec = llvm::dyn_cast(na.getValue())) { if (layoutSpec) { InFlightDiagnostic diag = emitOpError() << "expects at most one data layout attribute"; diff --git a/mlir/lib/IR/BuiltinDialectBytecode.cpp b/mlir/lib/IR/BuiltinDialectBytecode.cpp --- a/mlir/lib/IR/BuiltinDialectBytecode.cpp +++ b/mlir/lib/IR/BuiltinDialectBytecode.cpp @@ -32,7 +32,7 @@ static unsigned getIntegerBitWidth(DialectBytecodeReader &reader, Type type) { if (auto intType = dyn_cast(type)) { return intType.getWidth(); - } else if (type.isa()) { + } else if (llvm::isa(type)) { return IndexType::kInternalStorageBitWidth; } reader.emitError() diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -244,10 +244,10 @@ VectorType VectorType::scaleElementBitwidth(unsigned scale) { if (!scale) return VectorType(); - if (auto et = getElementType().dyn_cast()) + if (auto et = llvm::dyn_cast(getElementType())) if (auto scaledEt = et.scaleElementBitwidth(scale)) return VectorType::get(getShape(), scaledEt, getNumScalableDims()); - if (auto et = getElementType().dyn_cast()) + if (auto et = llvm::dyn_cast(getElementType())) if (auto scaledEt = et.scaleElementBitwidth(scale)) return VectorType::get(getShape(), scaledEt, getNumScalableDims()); return VectorType(); @@ -305,8 +305,8 @@ // Note: Non standard/builtin types are allowed to exist within tensor // types. Dialects are expected to verify that tensor types have a valid // element type within that dialect. - return type.isa() || + return llvm::isa(type) || !llvm::isa(type.getDialect()); } @@ -321,7 +321,7 @@ for (int64_t s : shape) if (s < 0 && !ShapedType::isDynamic(s)) return emitError() << "invalid tensor dimension size"; - if (auto v = encoding.dyn_cast_or_null()) + if (auto v = llvm::dyn_cast_or_null(encoding)) if (failed(v.verifyEncoding(shape, elementType, emitError))) return failure(); return checkTensorElementType(emitError, elementType); @@ -426,9 +426,9 @@ if (originalType == candidateReducedType) return SliceVerificationResult::Success; - ShapedType originalShapedType = originalType.cast(); + ShapedType originalShapedType = llvm::cast(originalType); ShapedType candidateReducedShapedType = - candidateReducedType.cast(); + llvm::cast(candidateReducedType); // Rank and size logic is valid for all ShapedTypes. ArrayRef originalShape = originalShapedType.getShape(); @@ -459,7 +459,7 @@ return true; // Supported built-in attributes. - if (memorySpace.isa()) + if (llvm::isa(memorySpace)) return true; // Allow custom dialect attributes. @@ -478,7 +478,7 @@ } Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) { - IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null(); + IntegerAttr intMemorySpace = llvm::dyn_cast_or_null(memorySpace); if (intMemorySpace && intMemorySpace.getValue() == 0) return nullptr; @@ -489,10 +489,10 @@ if (!memorySpace) return 0; - assert(memorySpace.isa() && + assert(llvm::isa(memorySpace) && "Using `getMemorySpaceInteger` with non-Integer attribute"); - return static_cast(memorySpace.cast().getInt()); + return static_cast(llvm::cast(memorySpace).getInt()); } unsigned MemRefType::getMemorySpaceAsInt() const { @@ -786,7 +786,7 @@ SmallVectorImpl &strides, int64_t &offset) { // Happy path: the type uses the strided layout directly. - if (auto strided = t.getLayout().dyn_cast()) { + if (auto strided = llvm::dyn_cast(t.getLayout())) { llvm::append_range(strides, strided.getStrides()); offset = strided.getOffset(); return success(); @@ -834,7 +834,7 @@ /// (i32, tensor, f32, i64) void TupleType::getFlattenedTypes(SmallVectorImpl &types) { for (Type type : getTypes()) { - if (auto nestedTuple = type.dyn_cast()) + if (auto nestedTuple = llvm::dyn_cast(type)) nestedTuple.getFlattenedTypes(types); else types.push_back(type); diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp --- a/mlir/lib/IR/Diagnostics.cpp +++ b/mlir/lib/IR/Diagnostics.cpp @@ -259,7 +259,7 @@ return; auto &os = llvm::errs(); - if (!diag.getLocation().isa()) + if (!llvm::isa(diag.getLocation())) os << diag.getLocation() << ": "; os << "error: "; @@ -448,7 +448,7 @@ if (!fileLoc) { std::string str; llvm::raw_string_ostream strOS(str); - if (!loc.isa()) + if (!llvm::isa(loc)) strOS << loc << ": "; strOS << message; return mgr.PrintMessage(os, SMLoc(), getDiagKind(kind), strOS.str()); @@ -983,7 +983,7 @@ // Print each diagnostic with the format: // ": : " - if (!diag.getLocation().isa()) + if (!llvm::isa(diag.getLocation())) os << diag.getLocation() << ": "; switch (diag.getSeverity()) { case DiagnosticSeverity::Error: diff --git a/mlir/lib/IR/ExtensibleDialect.cpp b/mlir/lib/IR/ExtensibleDialect.cpp --- a/mlir/lib/IR/ExtensibleDialect.cpp +++ b/mlir/lib/IR/ExtensibleDialect.cpp @@ -474,7 +474,7 @@ LogicalResult ExtensibleDialect::printIfDynamicType(Type type, AsmPrinter &printer) { - if (auto dynType = type.dyn_cast()) { + if (auto dynType = llvm::dyn_cast(type)) { dynType.print(printer); return success(); } @@ -496,7 +496,7 @@ LogicalResult ExtensibleDialect::printIfDynamicAttr(Attribute attribute, AsmPrinter &printer) { - if (auto dynAttr = attribute.dyn_cast()) { + if (auto dynAttr = llvm::dyn_cast(attribute)) { dynAttr.print(printer); return success(); } diff --git a/mlir/lib/IR/FunctionImplementation.cpp b/mlir/lib/IR/FunctionImplementation.cpp --- a/mlir/lib/IR/FunctionImplementation.cpp +++ b/mlir/lib/IR/FunctionImplementation.cpp @@ -250,14 +250,14 @@ "Invalid number of attributes."); auto &os = p.getStream(); - bool needsParens = types.size() > 1 || types[0].isa() || - (attrs && !attrs[0].cast().empty()); + bool needsParens = types.size() > 1 || llvm::isa(types[0]) || + (attrs && !llvm::cast(attrs[0]).empty()); if (needsParens) os << '('; llvm::interleaveComma(llvm::seq(0, types.size()), os, [&](size_t i) { p.printType(types[i]); if (attrs) - p.printOptionalAttrDict(attrs[i].cast().getValue()); + p.printOptionalAttrDict(llvm::cast(attrs[i]).getValue()); }); if (needsParens) os << ')'; @@ -278,12 +278,13 @@ if (!isExternal) { ArrayRef attrs; if (argAttrs) - attrs = argAttrs[i].cast().getValue(); + attrs = llvm::cast(argAttrs[i]).getValue(); p.printRegionArgument(body.getArgument(i), attrs); } else { p.printType(argTypes[i]); if (argAttrs) - p.printOptionalAttrDict(argAttrs[i].cast().getValue()); + p.printOptionalAttrDict( + llvm::cast(argAttrs[i]).getValue()); } } diff --git a/mlir/lib/IR/FunctionInterfaces.cpp b/mlir/lib/IR/FunctionInterfaces.cpp --- a/mlir/lib/IR/FunctionInterfaces.cpp +++ b/mlir/lib/IR/FunctionInterfaces.cpp @@ -21,14 +21,14 @@ //===----------------------------------------------------------------------===// static bool isEmptyAttrDict(Attribute attr) { - return attr.cast().empty(); + return llvm::cast(attr).empty(); } DictionaryAttr function_interface_impl::getArgAttrDict(FunctionOpInterface op, unsigned index) { ArrayAttr attrs = op.getArgAttrsAttr(); DictionaryAttr argAttrs = - attrs ? attrs[index].cast() : DictionaryAttr(); + attrs ? llvm::cast(attrs[index]) : DictionaryAttr(); return argAttrs; } @@ -37,7 +37,7 @@ unsigned index) { ArrayAttr attrs = op.getResAttrsAttr(); DictionaryAttr resAttrs = - attrs ? attrs[index].cast() : DictionaryAttr(); + attrs ? llvm::cast(attrs[index]) : DictionaryAttr(); return resAttrs; } @@ -288,7 +288,7 @@ newArgAttrs.reserve(argAttrs.size()); for (unsigned i = 0, e = argIndices.size(); i < e; ++i) if (!argIndices[i]) - newArgAttrs.emplace_back(argAttrs[i].cast()); + newArgAttrs.emplace_back(llvm::cast(argAttrs[i])); setAllArgAttrDicts(op, newArgAttrs); } @@ -309,7 +309,7 @@ newResultAttrs.reserve(resAttrs.size()); for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) if (!resultIndices[i]) - newResultAttrs.emplace_back(resAttrs[i].cast()); + newResultAttrs.emplace_back(llvm::cast(resAttrs[i])); setAllResultAttrDicts(op, newResultAttrs); } diff --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp --- a/mlir/lib/IR/Location.cpp +++ b/mlir/lib/IR/Location.cpp @@ -64,8 +64,8 @@ /// Methods for support type inquiry through isa, cast, and dyn_cast. bool LocationAttr::classof(Attribute attr) { - return attr.isa(); + return llvm::isa(attr); } //===----------------------------------------------------------------------===// @@ -101,7 +101,7 @@ } } // Otherwise, only add known locations to the set. - if (!loc.isa()) + if (!llvm::isa(loc)) decomposedLocs.insert(loc); } locs = decomposedLocs.getArrayRef(); diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -829,11 +829,11 @@ /// If this is a vector type, or a tensor type, return the scalar element type /// that it is built around, otherwise return the type unmodified. static Type getTensorOrVectorElementType(Type type) { - if (auto vec = type.dyn_cast()) + if (auto vec = llvm::dyn_cast(type)) return vec.getElementType(); // Look through tensor> to find the underlying element type. - if (auto tensor = type.dyn_cast()) + if (auto tensor = llvm::dyn_cast(type)) return getTensorOrVectorElementType(tensor.getElementType()); return type; } @@ -867,7 +867,7 @@ LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) { for (auto opType : op->getOperandTypes()) { auto type = getTensorOrVectorElementType(opType); - if (!type.isa()) + if (!llvm::isa(type)) return op->emitOpError("requires a float type"); } return success(); @@ -1102,7 +1102,7 @@ LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) { for (auto resultType : op->getResultTypes()) - if (!getTensorOrVectorElementType(resultType).isa()) + if (!llvm::isa(getTensorOrVectorElementType(resultType))) return op->emitOpError() << "requires a floating point type"; return success(); @@ -1169,7 +1169,7 @@ LogicalResult OpTrait::impl::verifyElementwise(Operation *op) { auto isMappableType = [](Type type) { - return type.isa(); + return llvm::isa(type); }; auto resultMappableTypes = llvm::to_vector<1>( llvm::make_filter_range(op->getResultTypes(), isMappableType)); diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -59,7 +59,7 @@ } if (!dictionarySorted.getPointer()) dictionarySorted.setPointer(DictionaryAttr::getWithSorted(context, attrs)); - return dictionarySorted.getPointer().cast(); + return llvm::cast(dictionarySorted.getPointer()); } /// Add an attribute with the specified name. @@ -405,18 +405,19 @@ OperandRangeRange::OperandRangeRange(OperandRange operands, Attribute operandSegments) : OperandRangeRange(OwnerT(operands.getBase(), operandSegments), 0, - operandSegments.cast().size()) {} + llvm::cast(operandSegments).size()) { +} OperandRange OperandRangeRange::join() const { const OwnerT &owner = getBase(); - ArrayRef sizeData = owner.second.cast(); + ArrayRef sizeData = llvm::cast(owner.second); return OperandRange(owner.first, std::accumulate(sizeData.begin(), sizeData.end(), 0)); } OperandRange OperandRangeRange::dereference(const OwnerT &object, ptrdiff_t index) { - ArrayRef sizeData = object.second.cast(); + ArrayRef sizeData = llvm::cast(object.second); uint32_t startIndex = std::accumulate(sizeData.begin(), sizeData.begin() + index, 0); return OperandRange(object.first + startIndex, *(sizeData.begin() + index)); @@ -508,7 +509,7 @@ // Update any of the provided segment attributes. for (OperandSegment &segment : operandSegments) { - auto attr = segment.second.getValue().cast(); + auto attr = llvm::cast(segment.second.getValue()); SmallVector segments(attr.asArrayRef()); segments[segment.first] += diff; segment.second.setValue( @@ -524,7 +525,8 @@ const MutableOperandRange &operands, NamedAttribute operandSegmentAttr) : MutableOperandRangeRange( OwnerT(operands, operandSegmentAttr), 0, - operandSegmentAttr.getValue().cast().size()) {} + llvm::cast(operandSegmentAttr.getValue()).size()) { +} MutableOperandRange MutableOperandRangeRange::join() const { return getBase().first; @@ -537,7 +539,7 @@ MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object, ptrdiff_t index) { ArrayRef sizeData = - object.second.getValue().cast(); + llvm::cast(object.second.getValue()); uint32_t startIndex = std::accumulate(sizeData.begin(), sizeData.begin() + index, 0); return object.first.slice( @@ -782,8 +784,8 @@ auto sortValues = [](ValueRange values) { SmallVector sortedValues = llvm::to_vector(values); llvm::sort(sortedValues, [](Value a, Value b) { - auto aArg = a.dyn_cast(); - auto bArg = b.dyn_cast(); + auto aArg = llvm::dyn_cast(a); + auto bArg = llvm::dyn_cast(b); // Case 1. Both `a` and `b` are `BlockArgument`s. if (aArg && bArg) { diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -459,7 +459,7 @@ // Verify the visibility attribute. if (Attribute vis = op->getAttr(mlir::SymbolTable::getVisibilityAttrName())) { - StringAttr visStrAttr = vis.dyn_cast(); + StringAttr visStrAttr = llvm::dyn_cast(vis); if (!visStrAttr) return op->emitOpError() << "requires visibility attribute '" << mlir::SymbolTable::getVisibilityAttrName() @@ -669,7 +669,7 @@ // If the references are not pointer equal, check to see if `subRef` is a // prefix of `ref`. - if (ref.isa() || + if (llvm::isa(ref) || ref.getRootReference() != subRef.getRootReference()) return false; @@ -789,7 +789,7 @@ /// Generates a new symbol reference attribute with a new leaf reference. static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr, FlatSymbolRefAttr newLeafAttr) { - if (oldAttr.isa()) + if (llvm::isa(oldAttr)) return newLeafAttr; auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences()); nestedRefs.back() = newLeafAttr; diff --git a/mlir/lib/IR/TypeUtilities.cpp b/mlir/lib/IR/TypeUtilities.cpp --- a/mlir/lib/IR/TypeUtilities.cpp +++ b/mlir/lib/IR/TypeUtilities.cpp @@ -22,7 +22,7 @@ using namespace mlir; Type mlir::getElementTypeOrSelf(Type type) { - if (auto st = type.dyn_cast()) + if (auto st = llvm::dyn_cast(type)) return st.getElementType(); return type; } @@ -32,7 +32,7 @@ } Type mlir::getElementTypeOrSelf(Attribute attr) { - if (auto typedAttr = attr.dyn_cast()) + if (auto typedAttr = llvm::dyn_cast(attr)) return getElementTypeOrSelf(typedAttr.getType()); return {}; } @@ -47,7 +47,7 @@ /// dialect and typeData. bool mlir::isOpaqueTypeWithName(Type type, StringRef dialect, StringRef typeData) { - if (auto opaque = type.dyn_cast()) + if (auto opaque = llvm::dyn_cast(type)) return opaque.getDialectNamespace() == dialect && opaque.getTypeData() == typeData; return false; @@ -76,8 +76,8 @@ /// compatible if at least one is dynamic or both are equal. The element type /// does not matter. LogicalResult mlir::verifyCompatibleShape(Type type1, Type type2) { - auto sType1 = type1.dyn_cast(); - auto sType2 = type2.dyn_cast(); + auto sType1 = llvm::dyn_cast(type1); + auto sType2 = llvm::dyn_cast(type2); // Either both or neither type should be shaped. if (!sType1) @@ -120,7 +120,7 @@ /// dims are equal. The element type does not matter. LogicalResult mlir::verifyCompatibleShapes(TypeRange types) { auto shapedTypes = llvm::to_vector<8>(llvm::map_range( - types, [](auto type) { return type.template dyn_cast(); })); + types, [](auto type) { return llvm::dyn_cast(type); })); // Return failure if some, but not all are not shaped. Return early if none // are shaped also. if (llvm::none_of(shapedTypes, [](auto t) { return t; })) @@ -132,7 +132,7 @@ bool hasScalableVecTypes = false; bool hasNonScalableVecTypes = false; for (Type t : types) { - auto vType = t.dyn_cast(); + auto vType = llvm::dyn_cast(t); if (vType && vType.isScalable()) hasScalableVecTypes = true; else @@ -167,9 +167,9 @@ } Type OperandElementTypeIterator::mapElement(Value value) const { - return value.getType().cast().getElementType(); + return llvm::cast(value.getType()).getElementType(); } Type ResultElementTypeIterator::mapElement(Value value) const { - return value.getType().cast().getElementType(); + return llvm::cast(value.getType()).getElementType(); } diff --git a/mlir/lib/IR/Verifier.cpp b/mlir/lib/IR/Verifier.cpp --- a/mlir/lib/IR/Verifier.cpp +++ b/mlir/lib/IR/Verifier.cpp @@ -302,7 +302,7 @@ } // Block argument case. Block *block1 = op.getBlock(); - Block *block2 = operand.cast().getOwner(); + Block *block2 = llvm::cast(operand).getOwner(); Region *region1 = block1->getParent(); Region *region2 = block2->getParent(); Location loc = UnknownLoc::get(op.getContext()); diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -783,7 +783,7 @@ } if (auto dense = dyn_cast(attr)) { if (auto iType = dyn_cast( - dense.getType().cast().getElementType())) { + llvm::cast(dense.getType()).getElementType())) { os << '{'; interleaveComma(dense, os, [&](const APInt &val) { printInt(val, shouldMapToUnsigned(iType.getSignedness())); @@ -792,7 +792,7 @@ return success(); } if (auto iType = dyn_cast( - dense.getType().cast().getElementType())) { + llvm::cast(dense.getType()).getElementType())) { os << '{'; interleaveComma(dense, os, [&](const APInt &val) { printInt(val, false); }); diff --git a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp --- a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp @@ -21,8 +21,8 @@ /// A utility walker that interrupts if the operation has valid debug /// information. static WalkResult interruptIfValidLocation(Operation *op) { - return op->getLoc().isa() ? WalkResult::advance() - : WalkResult::interrupt(); + return llvm::isa(op->getLoc()) ? WalkResult::advance() + : WalkResult::interrupt(); } DebugTranslation::DebugTranslation(Operation *module, llvm::Module &llvmModule) @@ -45,7 +45,7 @@ if (auto targetTripleAttr = module->getAttr(LLVM::LLVMDialect::getTargetTripleAttrName())) { auto targetTriple = - llvm::Triple(targetTripleAttr.cast().getValue()); + llvm::Triple(llvm::cast(targetTripleAttr).getValue()); if (targetTriple.isKnownWindowsMSVCEnvironment()) { // Dwarf debugging files will be generated by default, unless "CodeView" // is set explicitly. Windows/MSVC should use CodeView instead. @@ -68,8 +68,8 @@ const bool hasCallWithoutDebugInfo = func.walk([&](LLVM::CallOp call) { return call.getLoc()->walk([](Location l) { - return l.isa() ? WalkResult::interrupt() - : WalkResult::advance(); + return llvm::isa(l) ? WalkResult::interrupt() + : WalkResult::advance(); }); }) .wasInterrupted(); @@ -273,7 +273,7 @@ DebugTranslation::translateLoc(Location loc, llvm::DILocalScope *scope, const llvm::DILocation *inlinedAt) { // LLVM doesn't have a representation for unknown. - if (!scope || loc.isa()) + if (!scope || llvm::isa(loc)) return nullptr; // Check for a cached instance. @@ -282,12 +282,12 @@ return existingIt->second; const llvm::DILocation *llvmLoc = nullptr; - if (auto callLoc = loc.dyn_cast()) { + if (auto callLoc = llvm::dyn_cast(loc)) { // For callsites, the caller is fed as the inlinedAt for the callee. const auto *callerLoc = translateLoc(callLoc.getCaller(), scope, inlinedAt); llvmLoc = translateLoc(callLoc.getCallee(), scope, callerLoc); - } else if (auto fileLoc = loc.dyn_cast()) { + } else if (auto fileLoc = llvm::dyn_cast(loc)) { llvm::DILocalScope *locationScope = scope; // Only construct a new DIFile when no local scope is present. This // prioritizes existing DI information when it's present. @@ -300,12 +300,12 @@ fileLoc.getColumn(), locationScope, const_cast(inlinedAt)); - } else if (auto fusedLoc = loc.dyn_cast()) { + } else if (auto fusedLoc = llvm::dyn_cast(loc)) { ArrayRef locations = fusedLoc.getLocations(); // Check for a scope encoded with the location. - if (auto scopedAttr = - fusedLoc.getMetadata().dyn_cast_or_null()) + if (auto scopedAttr = llvm::dyn_cast_or_null( + fusedLoc.getMetadata())) scope = translate(scopedAttr); // For fused locations, merge each of the nodes. @@ -315,10 +315,10 @@ llvmLoc, translateLoc(locIt, scope, inlinedAt)); } - } else if (auto nameLoc = loc.dyn_cast()) { + } else if (auto nameLoc = llvm::dyn_cast(loc)) { llvmLoc = translateLoc(nameLoc.getChildLoc(), scope, inlinedAt); - } else if (auto opaqueLoc = loc.dyn_cast()) { + } else if (auto opaqueLoc = llvm::dyn_cast(loc)) { llvmLoc = translateLoc(opaqueLoc.getFallbackLocation(), scope, inlinedAt); } else { llvm_unreachable("unknown location kind"); diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -231,9 +231,9 @@ Attribute attr = it.value(); if (!attr) continue; - DictionaryAttr dAttr = attr.cast(); - TypeAttr tAttr = - dAttr.get(InlineAsmOp::getElementTypeAttrName()).cast(); + DictionaryAttr dAttr = llvm::cast(attr); + TypeAttr tAttr = llvm::cast( + dAttr.get(InlineAsmOp::getElementTypeAttrName())); llvm::AttrBuilder b(moduleTranslation.getLLVMContext()); llvm::Type *ty = moduleTranslation.convertType(tAttr.getValue()); b.addTypeAttr(llvm::Attribute::ElementType, ty); diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp @@ -162,7 +162,7 @@ ->addOperand(llvmMetadataNode); }; if (attribute.getName() == NVVM::NVVMDialect::getMaxntidAttrName()) { - if (!attribute.getValue().dyn_cast()) + if (!llvm::dyn_cast(attribute.getValue())) return failure(); SmallVector values = extractFromI64ArrayAttr(attribute.getValue()); @@ -172,7 +172,7 @@ if (values.size() > 2) generateMetadata(values[2], NVVM::NVVMDialect::getMaxntidZName()); } else if (attribute.getName() == NVVM::NVVMDialect::getReqntidAttrName()) { - if (!attribute.getValue().dyn_cast()) + if (!llvm::dyn_cast(attribute.getValue())) return failure(); SmallVector values = extractFromI64ArrayAttr(attribute.getValue()); @@ -183,10 +183,10 @@ generateMetadata(values[2], NVVM::NVVMDialect::getReqntidZName()); } else if (attribute.getName() == NVVM::NVVMDialect::getMinctasmAttrName()) { - auto value = attribute.getValue().dyn_cast(); + auto value = llvm::dyn_cast(attribute.getValue()); generateMetadata(value.getInt(), "minctasm"); } else if (attribute.getName() == NVVM::NVVMDialect::getMaxnregAttrName()) { - auto value = attribute.getValue().dyn_cast(); + auto value = llvm::dyn_cast(attribute.getValue()); generateMetadata(value.getInt(), "maxnreg"); } else if (attribute.getName() == NVVM::NVVMDialect::getKernelFuncAttrName()) { diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp @@ -107,7 +107,7 @@ dataPtr = builder.CreateExtractValue(dataValue, kPtrPosInDataDescriptor); dataSize = builder.CreateExtractValue(dataValue, kSizePosInDataDescriptor); - } else if (data.getType().isa()) { + } else if (llvm::isa(data.getType())) { dataPtrBase = dataValue; dataPtr = dataValue; dataSize = accBuilder->getSizeInBytes(dataValue); diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -367,7 +367,7 @@ if (criticalOp.getNameAttr()) { // The verifiers in OpenMP Dialect guarentee that all the pointers are // non-null - auto symbolRef = criticalOp.getNameAttr().cast(); + auto symbolRef = llvm::cast(criticalOp.getNameAttr()); auto criticalDeclareOp = SymbolTable::lookupNearestSymbolFrom(criticalOp, symbolRef); @@ -389,7 +389,8 @@ for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i) { if (container.getReductionVars()[i] != reduction.getAccumulator()) continue; - reductionSymbol = (*container.getReductions())[i].cast(); + reductionSymbol = + llvm::cast((*container.getReductions())[i]); break; } assert(reductionSymbol && @@ -704,8 +705,8 @@ for (auto dep : llvm::zip(taskOp.getDependVars(), taskOp.getDepends()->getValue())) { llvm::omp::RTLDependenceKindTy type; - switch ( - std::get<1>(dep).cast().getValue()) { + switch (llvm::cast(std::get<1>(dep)) + .getValue()) { case mlir::omp::ClauseTaskDepend::taskdependin: type = llvm::omp::RTLDependenceKindTy::DepIn; break; @@ -1379,7 +1380,7 @@ llvm::Value *mapOpPtr; llvm::Value *mapOpSize; - if (mapOp.getType().isa()) { + if (llvm::isa(mapOp.getType())) { mapOpPtrBase = mapOpValue; mapOpPtr = mapOpValue; mapOpSize = ompBuilder->getSizeInBytes(mapOpValue); @@ -1410,7 +1411,8 @@ {builder.getInt32(0), builder.getInt32(index)}); builder.CreateStore(mapOpSize, sizeGEP); - mapTypeFlags.push_back(mapTypeOp.dyn_cast().getInt()); + mapTypeFlags.push_back( + llvm::dyn_cast(mapTypeOp).getInt()); llvm::Constant *mapName = mlir::LLVM::createMappingInformation(mapOp.getLoc(), *ompBuilder); mapNames.push_back(mapName); @@ -1445,7 +1447,7 @@ if (auto constOp = mlir::dyn_cast( devId.getDefiningOp())) if (auto intAttr = - constOp.getValue().dyn_cast()) + llvm::dyn_cast(constOp.getValue())) deviceID = intAttr.getInt(); numMapOperands = dataOp.getMapOperands().size(); @@ -1464,7 +1466,7 @@ if (auto constOp = mlir::dyn_cast( devId.getDefiningOp())) if (auto intAttr = - constOp.getValue().dyn_cast()) + llvm::dyn_cast(constOp.getValue())) deviceID = intAttr.getInt(); numMapOperands = enterDataOp.getMapOperands().size(); @@ -1483,7 +1485,7 @@ if (auto constOp = mlir::dyn_cast( devId.getDefiningOp())) if (auto intAttr = - constOp.getValue().dyn_cast()) + llvm::dyn_cast(constOp.getValue())) deviceID = intAttr.getInt(); numMapOperands = exitDataOp.getMapOperands().size(); diff --git a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp @@ -109,7 +109,7 @@ auto func = dyn_cast(op); if (!func) return failure(); - auto value = attribute.getValue().dyn_cast(); + auto value = llvm::dyn_cast(attribute.getValue()); if (!value) return failure(); @@ -125,7 +125,7 @@ auto func = dyn_cast(op); if (!func) return failure(); - auto value = attribute.getValue().dyn_cast(); + auto value = llvm::dyn_cast(attribute.getValue()); if (!value) return failure(); @@ -142,7 +142,7 @@ auto func = dyn_cast(op); if (!func) return failure(); - auto value = attribute.getValue().dyn_cast(); + auto value = llvm::dyn_cast(attribute.getValue()); if (!value) return failure(); llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext(); diff --git a/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp b/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp --- a/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp @@ -190,7 +190,7 @@ void LoopAnnotationConversion::convertLocation(FusedLoc location) { auto localScopeAttr = - location.getMetadata().dyn_cast_or_null(); + llvm::dyn_cast_or_null(location.getMetadata()); if (!localScopeAttr) return; auto *localScope = dyn_cast( diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -630,7 +630,7 @@ if (!type) return nullptr; - if (type.isa()) + if (llvm::isa(type)) return type; // LLVM vectors can only contain scalars. @@ -647,12 +647,12 @@ } // LLVM arrays can contain other arrays or vectors. - if (auto arrayType = type.dyn_cast()) { + if (auto arrayType = llvm::dyn_cast(type)) { // Recover the nested array shape. SmallVector shape; shape.push_back(arrayType.getNumElements()); - while (arrayType.getElementType().isa()) { - arrayType = arrayType.getElementType().cast(); + while (llvm::isa(arrayType.getElementType())) { + arrayType = llvm::cast(arrayType.getElementType()); shape.push_back(arrayType.getNumElements()); } @@ -710,12 +710,12 @@ // Convert constant data to a dense elements attribute. if (auto *cd = dyn_cast(value)) { Type type = convertType(cd->getElementType()); - auto attrType = getStdTypeForAttr(convertType(cd->getType())) - .dyn_cast_or_null(); + auto attrType = llvm::dyn_cast_or_null( + getStdTypeForAttr(convertType(cd->getType()))); if (!attrType) return nullptr; - if (type.isa()) { + if (llvm::isa(type)) { SmallVector values; values.reserve(cd->getNumElements()); for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i) @@ -723,7 +723,7 @@ return DenseElementsAttr::get(attrType, values); } - if (type.isa()) { + if (llvm::isa(type)) { SmallVector values; values.reserve(cd->getNumElements()); for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i) @@ -737,8 +737,8 @@ // Unpack constant aggregates to create dense elements attribute whenever // possible. Return nullptr (failure) otherwise. if (isa(value)) { - auto outerType = getStdTypeForAttr(convertType(value->getType())) - .dyn_cast_or_null(); + auto outerType = llvm::dyn_cast_or_null( + getStdTypeForAttr(convertType(value->getType()))); if (!outerType) return nullptr; @@ -746,8 +746,8 @@ SmallVector shape; for (unsigned i = 0, e = value->getNumOperands(); i < e; ++i) { - auto nested = getConstantAsAttr(value->getAggregateElement(i)) - .dyn_cast_or_null(); + auto nested = llvm::dyn_cast_or_null( + getConstantAsAttr(value->getAggregateElement(i))); if (!nested) return nullptr; @@ -921,7 +921,7 @@ // Convert constants that can be represented as attributes. if (Attribute attr = getConstantAsAttr(constant)) { Type type = convertType(constant->getType()); - if (auto symbolRef = attr.dyn_cast()) { + if (auto symbolRef = llvm::dyn_cast(attr)) { return builder.create(loc, type, symbolRef.getValue()) .getResult(); } @@ -998,7 +998,7 @@ // Generate an UndefOp as root value and insert the aggregate elements. Type rootType = convertType(constant->getType()); - bool isArrayOrStruct = rootType.isa(); + bool isArrayOrStruct = llvm::isa(rootType); assert((isArrayOrStruct || LLVM::isCompatibleVectorType(rootType)) && "unrecognized aggregate type"); Value root = builder.create(loc, rootType); @@ -1558,7 +1558,7 @@ clearBlockAndValueMapping(); auto functionType = - convertType(func->getFunctionType()).dyn_cast(); + llvm::dyn_cast(convertType(func->getFunctionType())); if (func->isIntrinsic() && iface.isConvertibleIntrinsic(func->getIntrinsicID())) return success(); diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -73,7 +73,7 @@ if (!key) continue; if (key.getValue() == DLTIDialect::kDataLayoutEndiannessKey) { - auto value = entry.getValue().cast(); + auto value = llvm::cast(entry.getValue()); bool isLittleEndian = value.getValue() == DLTIDialect::kDataLayoutEndiannessLittle; layoutStream << "-" << (isLittleEndian ? "e" : "E"); @@ -81,7 +81,7 @@ continue; } if (key.getValue() == DLTIDialect::kDataLayoutAllocaMemorySpaceKey) { - auto value = entry.getValue().cast(); + auto value = llvm::cast(entry.getValue()); uint64_t space = value.getValue().getZExtValue(); // Skip the default address space. if (space == 0) @@ -91,7 +91,7 @@ continue; } if (key.getValue() == DLTIDialect::kDataLayoutStackAlignmentKey) { - auto value = entry.getValue().cast(); + auto value = llvm::cast(entry.getValue()); uint64_t alignment = value.getValue().getZExtValue(); // Skip the default stack alignment. if (alignment == 0) @@ -112,14 +112,14 @@ if (!type) continue; // Data layout for the index type is irrelevant at this point. - if (type.isa()) + if (llvm::isa(type)) continue; layoutStream << "-"; LogicalResult result = llvm::TypeSwitch(type) .Case([&](Type type) -> LogicalResult { - if (auto intType = type.dyn_cast()) { + if (auto intType = llvm::dyn_cast(type)) { if (intType.getSignedness() != IntegerType::Signless) return emitError(*loc) << "unsupported data layout for non-signless integer " @@ -250,7 +250,7 @@ // Compute the shape of all dimensions but the innermost. Note that the // innermost dimension may be that of the vector element type. - bool hasVectorElementType = type.getElementType().isa(); + bool hasVectorElementType = llvm::isa(type.getElementType()); unsigned numAggregates = denseElementsAttr.getNumElements() / (hasVectorElementType ? 1 @@ -261,7 +261,7 @@ // Handle the case of vector splat, LLVM has special support for it. if (denseElementsAttr.isSplat() && - (type.isa() || hasVectorElementType)) { + (llvm::isa(type) || hasVectorElementType)) { llvm::Constant *splatValue = LLVM::detail::getLLVMConstant( innermostLLVMType, denseElementsAttr.getSplatValue(), loc, moduleTranslation); @@ -277,8 +277,8 @@ // In case of non-splat, create a constructor for the innermost constant from // a piece of raw data. std::function buildCstData; - if (type.isa()) { - auto vectorElementType = type.getElementType().dyn_cast(); + if (llvm::isa(type)) { + auto vectorElementType = llvm::dyn_cast(type.getElementType()); if (vectorElementType && vectorElementType.getRank() == 1) { buildCstData = [&](StringRef data) { return llvm::ConstantDataVector::getRaw( @@ -290,7 +290,7 @@ innermostLLVMType); }; } - } else if (type.isa()) { + } else if (llvm::isa(type)) { buildCstData = [&](StringRef data) { return llvm::ConstantDataVector::getRaw(data, type.getShape().back(), innermostLLVMType); @@ -326,7 +326,7 @@ if (!attr) return llvm::UndefValue::get(llvmType); if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) { - auto arrayAttr = attr.dyn_cast(); + auto arrayAttr = llvm::dyn_cast(attr); if (!arrayAttr || arrayAttr.size() != 2) { emitError(loc, "expected struct type to be a complex number"); return nullptr; @@ -344,11 +344,11 @@ } // For integer types, we allow a mismatch in sizes as the index type in // MLIR might have a different size than the index type in the LLVM module. - if (auto intAttr = attr.dyn_cast()) + if (auto intAttr = llvm::dyn_cast(attr)) return llvm::ConstantInt::get( llvmType, intAttr.getValue().sextOrTrunc(llvmType->getIntegerBitWidth())); - if (auto floatAttr = attr.dyn_cast()) { + if (auto floatAttr = llvm::dyn_cast(attr)) { if (llvmType != llvm::Type::getFloatingPointTy(llvmType->getContext(), floatAttr.getValue().getSemantics())) { @@ -357,10 +357,10 @@ } return llvm::ConstantFP::get(llvmType, floatAttr.getValue()); } - if (auto funcAttr = attr.dyn_cast()) + if (auto funcAttr = llvm::dyn_cast(attr)) return llvm::ConstantExpr::getBitCast( moduleTranslation.lookupFunction(funcAttr.getValue()), llvmType); - if (auto splatAttr = attr.dyn_cast()) { + if (auto splatAttr = llvm::dyn_cast(attr)) { llvm::Type *elementType; uint64_t numElements; bool isScalable = false; @@ -401,13 +401,13 @@ // Try using raw elements data if possible. if (llvm::Constant *result = - convertDenseElementsAttr(loc, attr.dyn_cast(), + convertDenseElementsAttr(loc, llvm::dyn_cast(attr), llvmType, moduleTranslation)) { return result; } // Fall back to element-by-element construction otherwise. - if (auto elementsAttr = attr.dyn_cast()) { + if (auto elementsAttr = llvm::dyn_cast(attr)) { assert(elementsAttr.getShapedType().hasStaticShape()); assert(!elementsAttr.getShapedType().getShape().empty() && "unexpected empty elements attribute shape"); @@ -428,7 +428,7 @@ return result; } - if (auto stringAttr = attr.dyn_cast()) { + if (auto stringAttr = llvm::dyn_cast(attr)) { return llvm::ConstantDataArray::get( moduleTranslation.getLLVMContext(), ArrayRef{stringAttr.getValue().data(), @@ -685,7 +685,8 @@ if (op.getValueOrNull()) { // String attributes are treated separately because they cannot appear as // in-function constants and are thus not supported by getLLVMConstant. - if (auto strAttr = op.getValueOrNull().dyn_cast_or_null()) { + if (auto strAttr = + llvm::dyn_cast_or_null(op.getValueOrNull())) { cst = llvm::ConstantDataArray::getString( llvmModule->getContext(), strAttr.getValue(), /*AddNull=*/false); type = cst->getType(); @@ -763,10 +764,11 @@ ctorOp ? llvm::appendToGlobalCtors : llvm::appendToGlobalDtors; for (auto symbolAndPriority : range) { llvm::Function *f = lookupFunction( - std::get<0>(symbolAndPriority).cast().getValue()); + llvm::cast(std::get<0>(symbolAndPriority)) + .getValue()); appendGlobalFn( *llvmModule, f, - std::get<1>(symbolAndPriority).cast().getInt(), + llvm::cast(std::get<1>(symbolAndPriority)).getInt(), /*Data=*/nullptr); } } @@ -830,20 +832,20 @@ return success(); for (Attribute attr : *attributes) { - if (auto stringAttr = attr.dyn_cast()) { + if (auto stringAttr = llvm::dyn_cast(attr)) { if (failed( checkedAddLLVMFnAttribute(loc, llvmFunc, stringAttr.getValue()))) return failure(); continue; } - auto arrayAttr = attr.dyn_cast(); + auto arrayAttr = llvm::dyn_cast(attr); if (!arrayAttr || arrayAttr.size() != 2) return emitError(loc) << "expected 'passthrough' to contain string or array attributes"; - auto keyAttr = arrayAttr[0].dyn_cast(); - auto valueAttr = arrayAttr[1].dyn_cast(); + auto keyAttr = llvm::dyn_cast(arrayAttr[0]); + auto valueAttr = llvm::dyn_cast(arrayAttr[1]); if (!keyAttr || !valueAttr) return emitError(loc) << "expected arrays within 'passthrough' to contain two strings"; @@ -985,7 +987,8 @@ // Convert result attributes. if (ArrayAttr allResultAttrs = function.getAllResultAttrs()) { - DictionaryAttr resultAttrs = allResultAttrs[0].cast(); + DictionaryAttr resultAttrs = + llvm::cast(allResultAttrs[0]); llvmFunc->addRetAttrs(convertParameterAttrs(resultAttrs)); } @@ -1133,7 +1136,7 @@ return; } - SymbolRefAttr tagRef = tagRefs[0].cast(); + SymbolRefAttr tagRef = llvm::cast(tagRefs[0]); llvm::MDNode *node = getTBAANode(op, tagRef); inst->setMetadata(llvm::LLVMContext::MD_tbaa, node); } @@ -1192,7 +1195,8 @@ // The type references are in 1, 3, 5, etc. positions. unsigned opNum = 1; for (Attribute typeAttr : tdOp.getMembers()) { - refNames.push_back(typeAttr.cast().getValue()); + refNames.push_back( + llvm::cast(typeAttr).getValue()); operandIndices.push_back(opNum); opNum += 2; } @@ -1299,7 +1303,8 @@ auto llvmModule = std::make_unique(name, llvmContext); if (auto dataLayoutAttr = m->getAttr(LLVM::LLVMDialect::getDataLayoutAttrName())) { - llvmModule->setDataLayout(dataLayoutAttr.cast().getValue()); + llvmModule->setDataLayout( + llvm::cast(dataLayoutAttr).getValue()); } else { FailureOr llvmDataLayout(llvm::DataLayout("")); if (auto iface = dyn_cast(m)) { @@ -1319,7 +1324,8 @@ } if (auto targetTripleAttr = m->getAttr(LLVM::LLVMDialect::getTargetTripleAttrName())) - llvmModule->setTargetTriple(targetTripleAttr.cast().getValue()); + llvmModule->setTargetTriple( + llvm::cast(targetTripleAttr).getValue()); // Inject declarations for `malloc` and `free` functions that can be used in // memref allocation/deallocation coming from standard ops lowering. diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp --- a/mlir/lib/Transforms/Mem2Reg.cpp +++ b/mlir/lib/Transforms/Mem2Reg.cpp @@ -99,7 +99,7 @@ info(std::move(info)) { #ifndef NDEBUG auto isResultOrNewBlockArgument = [&]() { - if (BlockArgument arg = slot.ptr.dyn_cast()) + if (BlockArgument arg = llvm::dyn_cast(slot.ptr)) return arg.getOwner()->getParentOp() == allocator; return slot.ptr.getDefiningOp() == allocator; }; diff --git a/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp b/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp --- a/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp +++ b/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp @@ -32,7 +32,7 @@ void TestMemRefStrideCalculation::runOnOperation() { llvm::outs() << "Testing: " << getOperation().getName() << "\n"; getOperation().walk([&](memref::AllocOp allocOp) { - auto memrefType = allocOp.getResult().getType().cast(); + auto memrefType = llvm::cast(allocOp.getResult().getType()); int64_t offset; SmallVector strides; if (failed(getStridesAndOffset(memrefType, strides, offset))) { diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp --- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp +++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp @@ -81,7 +81,7 @@ return WalkResult::skip(); } Value value = op->getOperand(0); - if (value.getType().isa() != + if (llvm::isa(value.getType()) != !op->hasAttrOfType("dim")) { // Op should have "dim" attribute if and only if the operand is an // index-typed value. @@ -119,7 +119,7 @@ if (reifyToFuncArgs) { // Reify in terms of function block arguments. stopCondition = stopCondition = [](Value v, std::optional d) { - auto bbArg = v.dyn_cast(); + auto bbArg = llvm::dyn_cast(v); if (!bbArg) return false; return isa( @@ -166,7 +166,8 @@ return WalkResult::skip(); } Value constOp = rewriter.create( - op->getLoc(), reified->get().cast().getInt()); + op->getLoc(), + llvm::cast(reified->get()).getInt()); rewriter.replaceOp(op, constOp); return WalkResult::skip(); } diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp @@ -38,9 +38,9 @@ bool changed = false; for (LinalgOp linalgOp : llvm::reverse(linalgOps)) { for (OpOperand &opOperand : linalgOp->getOpOperands()) { - if (opOperand.get().getType().isa()) + if (llvm::isa(opOperand.get().getType())) continue; - if (opOperand.get().getType().isa()) { + if (llvm::isa(opOperand.get().getType())) { // Tile and Fuse tensor input. if (opOperand.getOperandNumber() >= linalgOp.getNumDpsInputs()) continue; diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -75,7 +75,7 @@ if (parser.parseRSquare() || parser.parseGreater()) return Attribute(); return parser.getChecked( - parser.getContext(), type.cast(), elements); + parser.getContext(), llvm::cast(type), elements); } void TestI64ElementsAttr::print(AsmPrinter &printer) const { diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -287,18 +287,18 @@ for (DataLayoutEntryInterface entry : params) { // This is for testing purposes only, so assert well-formedness. assert(entry.isTypeEntry() && "unexpected identifier entry"); - assert(entry.getKey().get().isa() && + assert(llvm::isa(entry.getKey().get()) && "wrong type passed in"); - auto array = entry.getValue().dyn_cast(); + auto array = llvm::dyn_cast(entry.getValue()); assert(array && array.getValue().size() == 2 && "expected array of two elements"); - auto kind = array.getValue().front().dyn_cast(); + auto kind = llvm::dyn_cast(array.getValue().front()); (void)kind; assert(kind && (kind.getValue() == "size" || kind.getValue() == "alignment" || kind.getValue() == "preferred") && "unexpected kind"); - assert(array.getValue().back().isa()); + assert(llvm::isa(array.getValue().back())); } return success(); } @@ -306,10 +306,11 @@ unsigned TestTypeWithLayoutType::extractKind(DataLayoutEntryListRef params, StringRef expectedKind) const { for (DataLayoutEntryInterface entry : params) { - ArrayRef pair = entry.getValue().cast().getValue(); - StringRef kind = pair.front().cast().getValue(); + ArrayRef pair = + llvm::cast(entry.getValue()).getValue(); + StringRef kind = llvm::cast(pair.front()).getValue(); if (kind == expectedKind) - return pair.back().cast().getValue().getZExtValue(); + return llvm::cast(pair.back()).getValue().getZExtValue(); } return 1; } @@ -466,7 +467,7 @@ if (succeeded(printIfDynamicType(type, printer))) return; - auto rec = type.cast(); + auto rec = llvm::cast(type); printer << "test_rec<" << rec.getName(); if (!stack.contains(rec)) { printer << ", "; diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -109,10 +109,10 @@ mlir::test::TestProduceSelfHandleOrForwardOperandOp::apply( transform::TransformResults &results, transform::TransformState &state) { if (getOperation()->getNumOperands() != 0) { - results.set(getResult().cast(), + results.set(llvm::cast(getResult()), getOperation()->getOperand(0).getDefiningOp()); } else { - results.set(getResult().cast(), getOperation()); + results.set(llvm::cast(getResult()), getOperation()); } return DiagnosedSilenceableFailure::success(); } @@ -127,7 +127,7 @@ DiagnosedSilenceableFailure mlir::test::TestProduceValueHandleToSelfOperand::apply( transform::TransformResults &results, transform::TransformState &state) { - results.setValues(getOut().cast(), getIn()); + results.setValues(llvm::cast(getOut()), getIn()); return DiagnosedSilenceableFailure::success(); } @@ -249,13 +249,13 @@ for (Value value : values) { std::string note; llvm::raw_string_ostream os(note); - if (auto arg = value.dyn_cast()) { + if (auto arg = llvm::dyn_cast(value)) { os << "a block argument #" << arg.getArgNumber() << " in block #" << std::distance(arg.getOwner()->getParent()->begin(), arg.getOwner()->getIterator()) << " in region #" << arg.getOwner()->getParent()->getRegionNumber(); } else { - os << "an op result #" << value.cast().getResultNumber(); + os << "an op result #" << llvm::cast(value).getResultNumber(); } InFlightDiagnostic diag = ::emitRemark(value.getLoc()) << getMessage(); diag.attachNote() << "value handle points to " << os.str(); @@ -317,7 +317,7 @@ getOperation()))) return DiagnosedSilenceableFailure::definiteFailure(); if (getNumResults() > 0) - results.set(getResult(0).cast(), getOperation()); + results.set(llvm::cast(getResult(0)), getOperation()); return DiagnosedSilenceableFailure::success(); } @@ -339,7 +339,7 @@ transform::TransformState &state) { ArrayRef payloadOps = state.getPayloadOps(getTarget()); auto reversedOps = llvm::to_vector(llvm::reverse(payloadOps)); - results.set(getResult().cast(), reversedOps); + results.set(llvm::cast(getResult()), reversedOps); return DiagnosedSilenceableFailure::success(); } @@ -443,7 +443,8 @@ DiagnosedSilenceableFailure mlir::test::TestCopyPayloadOp::apply(transform::TransformResults &results, transform::TransformState &state) { - results.set(getCopy().cast(), state.getPayloadOps(getHandle())); + results.set(llvm::cast(getCopy()), + state.getPayloadOps(getHandle())); return DiagnosedSilenceableFailure::success(); } @@ -472,7 +473,7 @@ DiagnosedSilenceableFailure mlir::transform::TestDialectParamType::checkPayload( Location loc, ArrayRef payload) const { for (Attribute attr : payload) { - auto integerAttr = attr.dyn_cast(); + auto integerAttr = llvm::dyn_cast(attr); if (integerAttr && integerAttr.getType().isSignlessInteger(32)) continue; return emitSilenceableError(loc) @@ -534,7 +535,7 @@ if (Value param = getParam()) { values = llvm::to_vector( llvm::map_range(state.getParams(param), [](Attribute attr) -> uint32_t { - return attr.cast().getValue().getLimitedValue( + return llvm::cast(attr).getValue().getLimitedValue( UINT32_MAX); })); } @@ -544,7 +545,7 @@ llvm::map_range(values, [this, &builder](uint32_t value) -> Attribute { return builder.getI32IntegerAttr(value + getAddendum()); })); - results.setParams(getResult().cast(), result); + results.setParams(llvm::cast(getResult()), result); return DiagnosedSilenceableFailure::success(); } @@ -562,7 +563,7 @@ }); return builder.getI32IntegerAttr(count); })); - results.setParams(getResult().cast(), result); + results.setParams(llvm::cast(getResult()), result); return DiagnosedSilenceableFailure::success(); } @@ -570,12 +571,12 @@ mlir::test::TestProduceIntegerParamWithTypeOp::apply( transform::TransformResults &results, transform::TransformState &state) { Attribute zero = IntegerAttr::get(getType(), 0); - results.setParams(getResult().cast(), zero); + results.setParams(llvm::cast(getResult()), zero); return DiagnosedSilenceableFailure::success(); } LogicalResult mlir::test::TestProduceIntegerParamWithTypeOp::verify() { - if (!getType().isa()) { + if (!llvm::isa(getType())) { return emitOpError() << "expects an integer type"; } return success(); @@ -618,7 +619,7 @@ DiagnosedSilenceableFailure mlir::test::TestProduceNullPayloadOp::apply( transform::TransformResults &results, transform::TransformState &state) { SmallVector null({nullptr}); - results.set(getOut().cast(), null); + results.set(llvm::cast(getOut()), null); return DiagnosedSilenceableFailure::success(); } @@ -630,7 +631,7 @@ DiagnosedSilenceableFailure mlir::test::TestProduceNullParamOp::apply(transform::TransformResults &results, transform::TransformState &state) { - results.setParams(getOut().cast(), Attribute()); + results.setParams(llvm::cast(getOut()), Attribute()); return DiagnosedSilenceableFailure::success(); } @@ -642,7 +643,7 @@ DiagnosedSilenceableFailure mlir::test::TestProduceNullValueOp::apply(transform::TransformResults &results, transform::TransformState &state) { - results.setValues(getOut().cast(), Value()); + results.setValues(llvm::cast(getOut()), Value()); return DiagnosedSilenceableFailure::success(); } @@ -662,7 +663,7 @@ DiagnosedSilenceableFailure mlir::test::TestRequiredMemoryEffectsOp::apply( transform::TransformResults &results, transform::TransformState &state) { - results.set(getOut().cast(), state.getPayloadOps(getIn())); + results.set(llvm::cast(getOut()), state.getPayloadOps(getIn())); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -89,7 +89,7 @@ auto extract = dyn_cast(users); if (!extract) return std::nullopt; - auto vecType = extract.getResult().getType().cast(); + auto vecType = llvm::cast(extract.getResult().getType()); if (dstVec && dstVec != vecType) return std::nullopt; dstVec = vecType; @@ -430,7 +430,7 @@ static constexpr int64_t kSharedMemorySpace = 3; // Compute type of shared memory buffer. MemRefType memrefType; - if (auto vectorType = type.dyn_cast()) { + if (auto vectorType = llvm::dyn_cast(type)) { memrefType = MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {}, kSharedMemorySpace); @@ -535,7 +535,7 @@ // Create a map (d0, d1) -> (d1) to distribute along the inner // dimension. Once we support n-d distribution we can add more // complex cases. - VectorType vecType = val.getType().dyn_cast(); + VectorType vecType = llvm::dyn_cast(val.getType()); int64_t vecRank = vecType ? vecType.getRank() : 0; OpBuilder builder(val.getContext()); if (vecRank == 0) @@ -642,9 +642,10 @@ if (op->getName().getStringRef() != "test_create_broadcast") return; auto targetShape = - op->getResult(0).getType().cast().getShape(); + llvm::cast(op->getResult(0).getType()).getShape(); auto arrayAttr = - op->getAttr("broadcast_dims").cast().asArrayRef(); + llvm::cast(op->getAttr("broadcast_dims")) + .asArrayRef(); llvm::SetVector broadcastedDims; broadcastedDims.insert(arrayAttr.begin(), arrayAttr.end()); OpBuilder b(op); diff --git a/mlir/test/python/lib/PythonTestCAPI.cpp b/mlir/test/python/lib/PythonTestCAPI.cpp --- a/mlir/test/python/lib/PythonTestCAPI.cpp +++ b/mlir/test/python/lib/PythonTestCAPI.cpp @@ -16,7 +16,7 @@ python_test::PythonTestDialect) bool mlirAttributeIsAPythonTestTestAttribute(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirPythonTestTestAttributeGet(MlirContext context) { @@ -24,7 +24,7 @@ } bool mlirTypeIsAPythonTestTestType(MlirType type) { - return unwrap(type).isa(); + return llvm::isa(unwrap(type)); } MlirType mlirPythonTestTestTypeGet(MlirContext context) { diff --git a/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp b/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp --- a/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp +++ b/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp @@ -36,18 +36,18 @@ ASSERT_EQ(subElementTypes.size(), 4U); // !llvm.ptr> - ASSERT_TRUE(subElementTypes[0].isa()); + ASSERT_TRUE(llvm::isa(subElementTypes[0])); // !llvm.struct<"bar",...> - auto structType = subElementTypes[1].dyn_cast(); + auto structType = llvm::dyn_cast(subElementTypes[1]); ASSERT_TRUE(bool(structType)); ASSERT_TRUE(structType.getName().equals("bar")); // !llvm.ptr> - ASSERT_TRUE(subElementTypes[2].isa()); + ASSERT_TRUE(llvm::isa(subElementTypes[2])); // !llvm.struct<"foo",...> - structType = subElementTypes[3].dyn_cast(); + structType = llvm::dyn_cast(subElementTypes[3]); ASSERT_TRUE(bool(structType)); ASSERT_TRUE(structType.getName().equals("foo")); } diff --git a/mlir/unittests/IR/InterfaceAttachmentTest.cpp b/mlir/unittests/IR/InterfaceAttachmentTest.cpp --- a/mlir/unittests/IR/InterfaceAttachmentTest.cpp +++ b/mlir/unittests/IR/InterfaceAttachmentTest.cpp @@ -181,8 +181,8 @@ : public TestExternalFallbackTypeInterface::FallbackModel< TestExternalFallbackTypeVectorModel> { unsigned getBitwidth(Type type) const { - IntegerType elementType = - dyn_cast_or_null(type.cast().getElementType()); + IntegerType elementType = dyn_cast_or_null( + llvm::cast(type).getElementType()); return elementType ? elementType.getWidth() : 0; } }; diff --git a/mlir/unittests/TableGen/EnumsGenTest.cpp b/mlir/unittests/TableGen/EnumsGenTest.cpp --- a/mlir/unittests/TableGen/EnumsGenTest.cpp +++ b/mlir/unittests/TableGen/EnumsGenTest.cpp @@ -175,7 +175,7 @@ mlir::Type intType = mlir::IntegerType::get(&ctx, 32); mlir::Attribute intAttr = mlir::IntegerAttr::get(intType, 5); - EXPECT_TRUE(intAttr.isa()); + EXPECT_TRUE(llvm::isa(intAttr)); EXPECT_EQ(intAttr, enumAttr); } @@ -186,10 +186,10 @@ mlir::Attribute intAttr = mlir::IntegerAttr::get( intType, static_cast(BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3)); - EXPECT_TRUE(intAttr.isa()); - EXPECT_TRUE(intAttr.isa()); + EXPECT_TRUE(llvm::isa(intAttr)); + EXPECT_TRUE(llvm::isa(intAttr)); intAttr = mlir::IntegerAttr::get( intType, static_cast(BitEnumWithGroup::Bits0To3) | (1u << 6)); - EXPECT_FALSE(intAttr.isa()); + EXPECT_FALSE(llvm::isa(intAttr)); }