diff --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
--- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
+++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
@@ -59,11 +59,11 @@
   }
   unsigned firstInputIndex, lastInputIndex;
   if (region) {
-    firstInputIndex = inputs[0].cast<BlockArgument>().getArgNumber();
-    lastInputIndex = inputs.back().cast<BlockArgument>().getArgNumber();
+    firstInputIndex = cast<BlockArgument>(inputs[0]).getArgNumber();
+    lastInputIndex = cast<BlockArgument>(inputs.back()).getArgNumber();
   } else {
-    firstInputIndex = inputs[0].cast<OpResult>().getResultNumber();
-    lastInputIndex = inputs.back().cast<OpResult>().getResultNumber();
+    firstInputIndex = cast<OpResult>(inputs[0]).getResultNumber();
+    lastInputIndex = cast<OpResult>(inputs.back()).getResultNumber();
   }
   if (firstInputIndex > inputIndex || lastInputIndex < inputIndex) {
     output.push_back(inputValue);
@@ -186,9 +186,9 @@
   }
   --maxDepth;
-  if (BlockArgument arg = value.dyn_cast<BlockArgument>())
+  if (BlockArgument arg = dyn_cast<BlockArgument>(value))
     return collectUnderlyingAddressValues(arg, maxDepth, visited, output);
-  collectUnderlyingAddressValues(value.cast<OpResult>(), maxDepth, visited,
+  collectUnderlyingAddressValues(cast<OpResult>(value), maxDepth, visited,
                                  output);
 }
@@ -216,10 +216,10 @@
                                 Operation *&allocScopeOp) {
   // Try to get a memory effect interface for the parent operation.
   Operation *op;
-  if (BlockArgument arg = value.dyn_cast<BlockArgument>())
+  if (BlockArgument arg = dyn_cast<BlockArgument>(value))
     op = arg.getOwner()->getParentOp();
   else
-    op = value.cast<OpResult>().getOwner();
+    op = cast<OpResult>(value).getOwner();
   MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
   if (!interface)
     return failure();
@@ -305,7 +305,7 @@
   if (rhsParentOp->isProperAncestor(lhsAllocScope))
     return AliasResult::NoAlias;
   if (rhsParentOp == lhsAllocScope) {
-    BlockArgument rhsArg = rhs.dyn_cast<BlockArgument>();
+    BlockArgument rhsArg = dyn_cast<BlockArgument>(rhs);
     if (rhsArg && rhs.getParentBlock()->isEntryBlock())
       return AliasResult::NoAlias;
   }
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -94,7 +94,7 @@
   }));

   auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
-    auto result = v.dyn_cast<OpResult>();
+    auto result = dyn_cast<OpResult>(v);
     if (!result)
       return;
     assert(llvm::is_contained(op->getResults(), result));
@@ -139,7 +139,7 @@
   }));

   auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
-    auto arg = v.dyn_cast<BlockArgument>();
+    auto arg = dyn_cast<BlockArgument>(v);
     if (!arg)
       return;
     if (!llvm::is_contained(successor.getSuccessor()->getArguments(), arg))
@@ -179,7 +179,7 @@
   if (loopBound.has_value()) {
     if (loopBound->is<Attribute>()) {
       if (auto bound =
-              loopBound->get<Attribute>().dyn_cast_or_null<IntegerAttr>())
+              dyn_cast_or_null<IntegerAttr>(loopBound->get<Attribute>()))
         return bound.getValue();
     } else if (auto value = loopBound->dyn_cast<Value>()) {
       const IntegerValueRangeLattice *lattice =
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -240,7 +240,7 @@
   if (inputs.size() != lattices.size()) {
     if (point.dyn_cast<Operation *>()) {
       if (!inputs.empty())
-        firstIndex = inputs.front().cast<OpResult>().getResultNumber();
+        firstIndex = cast<OpResult>(inputs.front()).getResultNumber();
       visitNonControlFlowArgumentsImpl(
           branch,
           RegionSuccessor(
@@ -248,7 +248,7 @@
           lattices, firstIndex);
     } else {
       if (!inputs.empty())
-        firstIndex = inputs.front().cast<BlockArgument>().getArgNumber();
+        firstIndex = cast<BlockArgument>(inputs.front()).getArgNumber();
       Region *region = point.get<Block *>()->getParent();
       visitNonControlFlowArgumentsImpl(
           branch,
diff --git a/mlir/lib/Analysis/Liveness.cpp b/mlir/lib/Analysis/Liveness.cpp
--- a/mlir/lib/Analysis/Liveness.cpp
+++ b/mlir/lib/Analysis/Liveness.cpp
@@ -184,7 +184,7 @@
   if (Operation *defOp = value.getDefiningOp())
     currentBlock = defOp->getBlock();
   else
-    currentBlock = value.cast<BlockArgument>().getOwner();
+    currentBlock = cast<BlockArgument>(value).getOwner();
   toProcess.push_back(currentBlock);
   visited.insert(currentBlock);
@@ -280,7 +280,7 @@
   if (value.getDefiningOp())
     os << "val_" << valueIds[value];
   else {
-    auto blockArg = value.cast<BlockArgument>();
+    auto blockArg = cast<BlockArgument>(value);
     os << "arg" << blockArg.getArgNumber() << "@"
        << blockIds[blockArg.getOwner()];
   }
@@ -404,7 +404,7 @@
   Operation *endOfLiveRange = nullptr;
   // If it's a live in or a block argument, then the start is the beginning
   // of the block.
-  if (isLiveIn(value) || value.isa<BlockArgument>())
+  if (isLiveIn(value) || isa<BlockArgument>(value))
     startOfLiveRange = &block->front();
   else
     startOfLiveRange = block->findAncestorOpInBlock(*startOfLiveRange);
diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -95,7 +95,7 @@
     if (auto *definingOp = operand.getDefiningOp()) {
       if (backwardSlice->count(definingOp) == 0)
         getBackwardSliceImpl(definingOp, backwardSlice, filter);
-    } else if (auto blockArg = operand.dyn_cast<BlockArgument>()) {
+    } else if (auto blockArg = dyn_cast<BlockArgument>(operand)) {
       Block *block = blockArg.getOwner();
       Operation *parentOp = block->getParentOp();
       // TODO: determine whether we want to recurse backward into the other
@@ -132,7 +132,7 @@
     getBackwardSlice(definingOp, backwardSlice, filter, inclusive);
     return;
   }
-  Operation *bbAargOwner = root.cast<BlockArgument>().getOwner()->getParentOp();
+  Operation *bbAargOwner = cast<BlockArgument>(root).getOwner()->getParentOp();
   getBackwardSlice(bbAargOwner, backwardSlice, filter, inclusive);
 }
diff --git a/mlir/lib/AsmParser/AsmParserState.cpp b/mlir/lib/AsmParser/AsmParserState.cpp
--- a/mlir/lib/AsmParser/AsmParserState.cpp
+++ b/mlir/lib/AsmParser/AsmParserState.cpp
@@ -73,7 +73,7 @@
   for (auto &it : *opAndUseMapIt.second) {
     symbolOps.clear();
     if (failed(symbolTable.lookupSymbolIn(
-            opAndUseMapIt.first, it.first.cast<SymbolRefAttr>(), symbolOps)))
+            opAndUseMapIt.first, cast<SymbolRefAttr>(it.first), symbolOps)))
       continue;

     for (ArrayRef<SMRange> useRange : it.second) {
@@ -301,7 +301,7 @@
   }

   // Otherwise, this is a block argument.
-  BlockArgument arg = value.cast<BlockArgument>();
+  BlockArgument arg = cast<BlockArgument>(value);
   auto existingIt = impl->blocksToIdx.find(arg.getOwner());
   assert(existingIt != impl->blocksToIdx.end() &&
          "expected valid block definition for block argument");
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -348,7 +348,7 @@
     else if (!(type = parseType()))
       return nullptr;
   }
-  if (!type.isa<FloatType>())
+  if (!isa<FloatType>(type))
     return (emitError("floating point value not valid for specified type"),
             nullptr);
   return FloatAttr::get(type, isNegative ? -*val : *val);
@@ -416,7 +416,7 @@
     return nullptr;
   }

-  if (auto floatType = type.dyn_cast<FloatType>()) {
+  if (auto floatType = dyn_cast<FloatType>(type)) {
     std::optional<APFloat> result;
     if (failed(parseFloatFromIntegerLiteral(result, tok, isNegative,
                                             floatType.getFloatSemantics(),
@@ -425,7 +425,7 @@
     return FloatAttr::get(floatType, *result);
   }

-  if (!type.isa<IntegerType, IndexType>())
+  if (!isa<IntegerType, IndexType>(type))
     return emitError(loc, "integer literal not valid for specified type"),
            nullptr;
@@ -543,7 +543,7 @@
   // Check to see if we parse the literal from a hex string.
   if (hexStorage &&
-      (eltType.isIntOrIndexOrFloat() || eltType.isa<ComplexType>()))
+      (eltType.isIntOrIndexOrFloat() || isa<ComplexType>(eltType)))
     return getHexAttr(loc, type);

   // Check that the parsed storage size has the same number of elements to the
@@ -563,7 +563,7 @@
   // Handle complex types in the specific element type cases below.
   bool isComplex = false;
-  if (ComplexType complexTy = eltType.dyn_cast<ComplexType>()) {
+  if (ComplexType complexTy = dyn_cast<ComplexType>(eltType)) {
     eltType = complexTy.getElementType();
     isComplex = true;
   }
@@ -583,7 +583,7 @@
     return DenseElementsAttr::get(type, intValues);
   }
   // Handle floating point types.
-  if (FloatType floatTy = eltType.dyn_cast<FloatType>()) {
+  if (FloatType floatTy = dyn_cast<FloatType>(eltType)) {
     std::vector<APFloat> floatValues;
     if (failed(getFloatAttrElements(loc, floatTy, floatValues)))
       return nullptr;
@@ -711,7 +711,7 @@
 /// Build a Dense attribute with hex data for the given type.
 DenseElementsAttr TensorLiteralParser::getHexAttr(SMLoc loc, ShapedType type) {
   Type elementType = type.getElementType();
-  if (!elementType.isIntOrIndexOrFloat() && !elementType.isa<ComplexType>()) {
+  if (!elementType.isIntOrIndexOrFloat() && !isa<ComplexType>(elementType)) {
     p.emitError(loc)
         << "expected floating-point, integer, or complex element type, got "
         << elementType;
@@ -904,7 +904,7 @@
   Token token = p.getToken();
   std::optional<APFloat> result;
-  auto floatType = type.cast<FloatType>();
+  auto floatType = cast<FloatType>(type);
   if (p.consumeIf(Token::integer)) {
     // Parse an integer literal as a float.
     if (p.parseFloatFromIntegerLiteral(result, token, isNegative,
@@ -1025,7 +1025,7 @@
     return nullptr;
   }

-  ShapedType shapedType = attrType.dyn_cast<ShapedType>();
+  ShapedType shapedType = dyn_cast<ShapedType>(attrType);
   if (!shapedType) {
     emitError(typeLoc, "`dense_resource` expected a shaped type");
     return nullptr;
@@ -1048,7 +1048,7 @@
     return nullptr;
   }

-  auto sType = type.dyn_cast<ShapedType>();
+  auto sType = dyn_cast<ShapedType>(type);
   if (!sType) {
     emitError("elements literal must be a shaped type");
     return nullptr;
diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp
--- a/mlir/lib/AsmParser/DialectSymbolParser.cpp
+++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp
@@ -260,7 +260,7 @@
   });

   // Ensure that the attribute has the same type as requested.
-  auto typedAttr = attr.dyn_cast_or_null<TypedAttr>();
+  auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
   if (type && typedAttr && typedAttr.getType() != type) {
     emitError("attribute type different than expected: expected ")
         << type << ", but got " << typedAttr.getType();
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -1333,7 +1333,7 @@
   auto type = parseType();
   if (!type)
     return failure();
-  auto fnType = type.dyn_cast<FunctionType>();
+  auto fnType = dyn_cast<FunctionType>(type);
   if (!fnType)
     return mlir::emitError(typeLoc, "expected function type");
@@ -2352,7 +2352,7 @@
       if (!forwardRefPlaceholders.count(result))
         detailOS << result.getOwner()->getName() << ": ";
     } else {
-      detailOS << "arg #" << frontValue.cast<BlockArgument>().getArgNumber()
+      detailOS << "arg #" << cast<BlockArgument>(frontValue).getArgNumber()
                << ": ";
     }
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -129,7 +129,7 @@
   if (!elementType ||
       parseToken(Token::greater, "expected '>' in complex type"))
     return nullptr;
-  if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>())
+  if (!isa<FloatType>(elementType) && !isa<IntegerType>(elementType))
     return emitError(elementTypeLoc, "invalid element type for complex"),
            nullptr;
@@ -207,8 +207,8 @@
       if (!attr)
         return failure();

-      if (attr.isa<MemRefLayoutAttrInterface>()) {
-        layout = attr.cast<MemRefLayoutAttrInterface>();
+      if (isa<MemRefLayoutAttrInterface>(attr)) {
+        layout = cast<MemRefLayoutAttrInterface>(attr);
       } else if (memorySpace) {
         return emitError("multiple memory spaces specified in memref type");
       } else {
@@ -383,7 +383,7 @@
   Attribute encoding;
   if (consumeIf(Token::comma)) {
     encoding = parseAttribute();
-    if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>()) {
+    if (auto v = dyn_cast_or_null<VerifiableTensorEncoding>(encoding)) {
       if (failed(v.verifyEncoding(dimensions, elementType,
                                   [&] { return emitError(); })))
         return nullptr;
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -783,7 +783,7 @@
     Attribute baseResult;
     if (failed(parseAttribute(reader, baseResult)))
       return failure();
-    if ((result = baseResult.dyn_cast<T>()))
+    if ((result = dyn_cast<T>(baseResult)))
       return success();
     return reader.emitError("expected attribute of type: ",
                             llvm::getTypeName<T>(), ", but got: ", baseResult);
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
--- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
@@ -180,7 +180,7 @@
   // have a registered dialect when it got created. We don't want to encode this
   // as the builtin OpaqueAttr, we want to encode it as if the dialect was
   // actually loaded.
-  if (OpaqueAttr opaqueAttr = attr.dyn_cast<OpaqueAttr>()) {
+  if (OpaqueAttr opaqueAttr = dyn_cast<OpaqueAttr>(attr)) {
     numbering->dialect = &numberDialect(opaqueAttr.getDialectNamespace());
     return;
   }
@@ -310,7 +310,7 @@
   // registered dialect when it got created. We don't want to encode this as the
   // builtin OpaqueType, we want to encode it as if the dialect was actually
   // loaded.
-  if (OpaqueType opaqueType = type.dyn_cast<OpaqueType>()) {
+  if (OpaqueType opaqueType = dyn_cast<OpaqueType>(type)) {
     numbering->dialect = &numberDialect(opaqueType.getDialectNamespace());
     return;
   }
diff --git a/mlir/lib/CAPI/Dialect/PDL.cpp b/mlir/lib/CAPI/Dialect/PDL.cpp
--- a/mlir/lib/CAPI/Dialect/PDL.cpp
+++ b/mlir/lib/CAPI/Dialect/PDL.cpp
@@ -21,7 +21,7 @@
 //===---------------------------------------------------------------------===//

 bool mlirTypeIsAPDLType(MlirType type) {
-  return unwrap(type).isa<pdl::PDLType>();
+  return isa<pdl::PDLType>(unwrap(type));
 }

 //===---------------------------------------------------------------------===//
@@ -29,7 +29,7 @@
 //===---------------------------------------------------------------------===//

 bool mlirTypeIsAPDLAttributeType(MlirType type) {
-  return unwrap(type).isa<pdl::AttributeType>();
+  return isa<pdl::AttributeType>(unwrap(type));
 }

 MlirType mlirPDLAttributeTypeGet(MlirContext ctx) {
@@ -41,7 +41,7 @@
 //===---------------------------------------------------------------------===//

 bool mlirTypeIsAPDLOperationType(MlirType type) {
-  return unwrap(type).isa<pdl::OperationType>();
+  return isa<pdl::OperationType>(unwrap(type));
 }

 MlirType mlirPDLOperationTypeGet(MlirContext ctx) {
@@ -53,7 +53,7 @@
 //===---------------------------------------------------------------------===//

 bool mlirTypeIsAPDLRangeType(MlirType type) {
-  return unwrap(type).isa<pdl::RangeType>();
+  return isa<pdl::RangeType>(unwrap(type));
 }

 MlirType mlirPDLRangeTypeGet(MlirType elementType) {
@@ -61,7 +61,7 @@
 }

 MlirType mlirPDLRangeTypeGetElementType(MlirType type) {
-  return wrap(unwrap(type).cast<pdl::RangeType>().getElementType());
+  return wrap(cast<pdl::RangeType>(unwrap(type)).getElementType());
 }

 //===---------------------------------------------------------------------===//
@@ -69,7 +69,7 @@
 //===---------------------------------------------------------------------===//

 bool mlirTypeIsAPDLTypeType(MlirType type) {
-  return unwrap(type).isa<pdl::TypeType>();
+  return isa<pdl::TypeType>(unwrap(type));
 }

 MlirType mlirPDLTypeTypeGet(MlirContext ctx) {
@@ -81,7 +81,7 @@
 //===---------------------------------------------------------------------===//

 bool mlirTypeIsAPDLValueType(MlirType type) {
-  return unwrap(type).isa<pdl::ValueType>();
+  return isa<pdl::ValueType>(unwrap(type));
 }

 MlirType mlirPDLValueTypeGet(MlirContext ctx) {
diff --git a/mlir/lib/CAPI/Dialect/Quant.cpp b/mlir/lib/CAPI/Dialect/Quant.cpp
--- a/mlir/lib/CAPI/Dialect/Quant.cpp
+++ b/mlir/lib/CAPI/Dialect/Quant.cpp
@@ -20,7 +20,7 @@
 //===---------------------------------------------------------------------===//

 bool mlirTypeIsAQuantizedType(MlirType type) {
-  return unwrap(type).isa<quant::QuantizedType>();
+  return isa<quant::QuantizedType>(unwrap(type));
 }

 unsigned mlirQuantizedTypeGetSignedFlag() {
@@ -40,39 +40,37 @@
 }

 MlirType mlirQuantizedTypeGetExpressedType(MlirType type) {
-  return wrap(unwrap(type).cast<quant::QuantizedType>().getExpressedType());
+  return wrap(cast<quant::QuantizedType>(unwrap(type)).getExpressedType());
 }

 unsigned mlirQuantizedTypeGetFlags(MlirType type) {
-  return unwrap(type).cast<quant::QuantizedType>().getFlags();
+  return cast<quant::QuantizedType>(unwrap(type)).getFlags();
 }

 bool mlirQuantizedTypeIsSigned(MlirType type) {
-  return unwrap(type).cast<quant::QuantizedType>().isSigned();
+  return cast<quant::QuantizedType>(unwrap(type)).isSigned();
 }

 MlirType mlirQuantizedTypeGetStorageType(MlirType type) {
-  return wrap(unwrap(type).cast<quant::QuantizedType>().getStorageType());
+  return wrap(cast<quant::QuantizedType>(unwrap(type)).getStorageType());
 }

 int64_t mlirQuantizedTypeGetStorageTypeMin(MlirType type) {
-  return unwrap(type).cast<quant::QuantizedType>().getStorageTypeMin();
+  return cast<quant::QuantizedType>(unwrap(type)).getStorageTypeMin();
 }

 int64_t mlirQuantizedTypeGetStorageTypeMax(MlirType type) {
-  return unwrap(type).cast<quant::QuantizedType>().getStorageTypeMax();
+  return cast<quant::QuantizedType>(unwrap(type)).getStorageTypeMax();
 }

 unsigned mlirQuantizedTypeGetStorageTypeIntegralWidth(MlirType type) {
-  return unwrap(type)
-      .cast<quant::QuantizedType>()
-      .getStorageTypeIntegralWidth();
+  return cast<quant::QuantizedType>(unwrap(type)).getStorageTypeIntegralWidth();
 }

 bool mlirQuantizedTypeIsCompatibleExpressedType(MlirType type,
                                                 MlirType candidate) {
-  return unwrap(type).cast<quant::QuantizedType>().isCompatibleExpressedType(
-      unwrap(candidate));
+  return cast<quant::QuantizedType>(unwrap(type))
+      .isCompatibleExpressedType(unwrap(candidate));
 }

 MlirType mlirQuantizedTypeGetQuantizedElementType(MlirType type) {
@@ -81,19 +79,19 @@

 MlirType mlirQuantizedTypeCastFromStorageType(MlirType type,
                                               MlirType candidate) {
-  return wrap(unwrap(type).cast<quant::QuantizedType>().castFromStorageType(
-      unwrap(candidate)));
+  return wrap(cast<quant::QuantizedType>(unwrap(type))
+                  .castFromStorageType(unwrap(candidate)));
 }

 MlirType mlirQuantizedTypeCastToStorageType(MlirType type) {
   return wrap(quant::QuantizedType::castToStorageType(
-      unwrap(type).cast<quant::QuantizedType>()));
+      cast<quant::QuantizedType>(unwrap(type))));
 }

 MlirType mlirQuantizedTypeCastFromExpressedType(MlirType type,
                                                 MlirType candidate) {
-  return wrap(unwrap(type).cast<quant::QuantizedType>().castFromExpressedType(
-      unwrap(candidate)));
+  return wrap(cast<quant::QuantizedType>(unwrap(type))
+                  .castFromExpressedType(unwrap(candidate)));
 }

 MlirType mlirQuantizedTypeCastToExpressedType(MlirType type) {
@@ -102,9 +100,8 @@

 MlirType mlirQuantizedTypeCastExpressedToStorageType(MlirType type,
                                                      MlirType candidate) {
-  return wrap(
-      unwrap(type).cast<quant::QuantizedType>().castExpressedToStorageType(
-          unwrap(candidate)));
+  return wrap(cast<quant::QuantizedType>(unwrap(type))
+                  .castExpressedToStorageType(unwrap(candidate)));
 }

 //===---------------------------------------------------------------------===//
@@ -112,7 +109,7 @@
 //===---------------------------------------------------------------------===//

 bool mlirTypeIsAAnyQuantizedType(MlirType type) {
-  return unwrap(type).isa<quant::AnyQuantizedType>();
+  return isa<quant::AnyQuantizedType>(unwrap(type));
 }

 MlirType mlirAnyQuantizedTypeGet(unsigned flags, MlirType storageType,
@@ -128,7 +125,7 @@
 //===---------------------------------------------------------------------===//

 bool mlirTypeIsAUniformQuantizedType(MlirType type) {
-  return unwrap(type).isa<quant::UniformQuantizedType>();
+  return isa<quant::UniformQuantizedType>(unwrap(type));
 }

 MlirType mlirUniformQuantizedTypeGet(unsigned flags, MlirType storageType,
@@ -141,15 +138,15 @@
 }

 double mlirUniformQuantizedTypeGetScale(MlirType type) {
-  return unwrap(type).cast<quant::UniformQuantizedType>().getScale();
+  return cast<quant::UniformQuantizedType>(unwrap(type)).getScale();
 }

 int64_t mlirUniformQuantizedTypeGetZeroPoint(MlirType type) {
-  return unwrap(type).cast<quant::UniformQuantizedType>().getZeroPoint();
+  return cast<quant::UniformQuantizedType>(unwrap(type)).getZeroPoint();
 }

 bool mlirUniformQuantizedTypeIsFixedPoint(MlirType type) {
-  return unwrap(type).cast<quant::UniformQuantizedType>().isFixedPoint();
+  return cast<quant::UniformQuantizedType>(unwrap(type)).isFixedPoint();
 }

 //===---------------------------------------------------------------------===//
@@ -157,7 +154,7 @@
 //===---------------------------------------------------------------------===//

 bool mlirTypeIsAUniformQuantizedPerAxisType(MlirType type) {
-  return unwrap(type).isa<quant::UniformQuantizedPerAxisType>();
+  return isa<quant::UniformQuantizedPerAxisType>(unwrap(type));
 }

 MlirType mlirUniformQuantizedPerAxisTypeGet(
@@ -172,33 +169,29 @@
 }

 intptr_t mlirUniformQuantizedPerAxisTypeGetNumDims(MlirType type) {
-  return unwrap(type)
-      .cast<quant::UniformQuantizedPerAxisType>()
+  return cast<quant::UniformQuantizedPerAxisType>(unwrap(type))
       .getScales()
       .size();
 }

 double mlirUniformQuantizedPerAxisTypeGetScale(MlirType type, intptr_t pos) {
-  return unwrap(type)
-      .cast<quant::UniformQuantizedPerAxisType>()
+  return cast<quant::UniformQuantizedPerAxisType>(unwrap(type))
      .getScales()[pos];
 }

 int64_t mlirUniformQuantizedPerAxisTypeGetZeroPoint(MlirType type,
                                                     intptr_t pos) {
-  return unwrap(type)
-      .cast<quant::UniformQuantizedPerAxisType>()
+  return cast<quant::UniformQuantizedPerAxisType>(unwrap(type))
      .getZeroPoints()[pos];
 }

 int32_t mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(MlirType type) {
-  return unwrap(type)
-      .cast<quant::UniformQuantizedPerAxisType>()
+  return cast<quant::UniformQuantizedPerAxisType>(unwrap(type))
       .getQuantizedDimension();
 }

 bool mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type) {
-  return unwrap(type).cast<quant::UniformQuantizedPerAxisType>().isFixedPoint();
+  return cast<quant::UniformQuantizedPerAxisType>(unwrap(type)).isFixedPoint();
 }

 //===---------------------------------------------------------------------===//
@@ -206,7 +199,7 @@
 //===---------------------------------------------------------------------===//

 bool mlirTypeIsACalibratedQuantizedType(MlirType type) {
-  return unwrap(type).isa<quant::CalibratedQuantizedType>();
+  return isa<quant::CalibratedQuantizedType>(unwrap(type));
 }

 MlirType mlirCalibratedQuantizedTypeGet(MlirType expressedType, double min,
@@ -216,9 +209,9 @@
 }

 double mlirCalibratedQuantizedTypeGetMin(MlirType type) {
-  return unwrap(type).cast<quant::CalibratedQuantizedType>().getMin();
+  return cast<quant::CalibratedQuantizedType>(unwrap(type)).getMin();
 }

 double mlirCalibratedQuantizedTypeGetMax(MlirType type) {
-  return unwrap(type).cast<quant::CalibratedQuantizedType>().getMax();
+  return cast<quant::CalibratedQuantizedType>(unwrap(type)).getMax();
 }
diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
--- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp
+++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
@@ -42,7 +42,7 @@
     "MlirSparseTensorDimLevelType (C-API) and DimLevelType (C++) mismatch");

 bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
-  return unwrap(attr).isa<SparseTensorEncodingAttr>();
+  return isa<SparseTensorEncodingAttr>(unwrap(attr));
 }

 MlirAttribute mlirSparseTensorEncodingAttrGet(
@@ -60,29 +60,28 @@
 }

 MlirAffineMap mlirSparseTensorEncodingAttrGetDimOrdering(MlirAttribute attr) {
-  return wrap(unwrap(attr).cast<SparseTensorEncodingAttr>().getDimOrdering());
+  return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getDimOrdering());
 }

 MlirAffineMap
 mlirSparseTensorEncodingAttrGetHigherOrdering(MlirAttribute attr) {
-  return wrap(
-      unwrap(attr).cast<SparseTensorEncodingAttr>().getHigherOrdering());
+  return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getHigherOrdering());
 }

 intptr_t mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr) {
-  return unwrap(attr).cast<SparseTensorEncodingAttr>().getLvlRank();
+  return cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlRank();
 }

 MlirSparseTensorDimLevelType
 mlirSparseTensorEncodingAttrGetDimLevelType(MlirAttribute attr, intptr_t lvl) {
   return static_cast<MlirSparseTensorDimLevelType>(
-      unwrap(attr).cast<SparseTensorEncodingAttr>().getLvlType(lvl));
+      cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlType(lvl));
 }

 int mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr) {
-  return unwrap(attr).cast<SparseTensorEncodingAttr>().getPosWidth();
+  return cast<SparseTensorEncodingAttr>(unwrap(attr)).getPosWidth();
 }

 int mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr) {
-  return unwrap(attr).cast<SparseTensorEncodingAttr>().getCrdWidth();
+  return cast<SparseTensorEncodingAttr>(unwrap(attr)).getCrdWidth();
 }
diff --git a/mlir/lib/CAPI/Dialect/Transform.cpp b/mlir/lib/CAPI/Dialect/Transform.cpp
--- a/mlir/lib/CAPI/Dialect/Transform.cpp
+++ b/mlir/lib/CAPI/Dialect/Transform.cpp
@@ -22,7 +22,7 @@
 //===---------------------------------------------------------------------===//

 bool mlirTypeIsATransformAnyOpType(MlirType type) {
-  return unwrap(type).isa<transform::AnyOpType>();
+  return isa<transform::AnyOpType>(unwrap(type));
 }

 MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) {
@@ -34,7 +34,7 @@
 //===---------------------------------------------------------------------===//

 bool mlirTypeIsATransformOperationType(MlirType type) {
-  return unwrap(type).isa<transform::OperationType>();
+  return isa<transform::OperationType>(unwrap(type));
 }

 MlirType mlirTransformOperationTypeGet(MlirContext ctx,
@@ -44,5 +44,5 @@
 }

 MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type) {
-  return wrap(unwrap(type).cast<transform::OperationType>().getOperationName());
+  return wrap(cast<transform::OperationType>(unwrap(type)).getOperationName());
 }
diff --git a/mlir/lib/CAPI/Interfaces/Interfaces.cpp b/mlir/lib/CAPI/Interfaces/Interfaces.cpp
--- a/mlir/lib/CAPI/Interfaces/Interfaces.cpp
+++ b/mlir/lib/CAPI/Interfaces/Interfaces.cpp
@@ -56,7 +56,7 @@
   (void)unwrapList(nOperands, operands, unwrappedOperands);
   DictionaryAttr attributeDict;
   if (!mlirAttributeIsNull(attributes))
-    attributeDict = unwrap(attributes).cast<DictionaryAttr>();
+    attributeDict = cast<DictionaryAttr>(unwrap(attributes));

   // Create a vector of unique pointers to regions and make sure they are not
   // deleted when exiting the scope. This is a hack caused by C++ API expecting
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -48,7 +48,7 @@
     Location loc = gpuOp.getLoc();
     Value memref = adaptor.getMemref();
     Value unconvertedMemref = gpuOp.getMemref();
-    MemRefType memrefType = unconvertedMemref.getType().cast<MemRefType>();
+    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());

     if (chipset.majorVersion < 9)
       return gpuOp.emitOpError("Raw buffer ops require GCN or higher");
@@ -85,13 +85,13 @@
     // so bitcast any floats to integers.
     Type llvmBufferValType = llvmWantedDataType;
     if (atomicCmpData) {
-      if (wantedDataType.isa<VectorType>())
+      if (isa<VectorType>(wantedDataType))
         return gpuOp.emitOpError("vector compare-and-swap does not exist");
-      if (auto floatType = wantedDataType.dyn_cast<FloatType>())
+      if (auto floatType = dyn_cast<FloatType>(wantedDataType))
         llvmBufferValType = this->getTypeConverter()->convertType(
             rewriter.getIntegerType(floatType.getWidth()));
     }
-    if (auto dataVector = wantedDataType.dyn_cast<VectorType>()) {
+    if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
       uint32_t elemBits = dataVector.getElementTypeBitWidth();
       uint32_t totalBits = elemBits * dataVector.getNumElements();
       if (totalBits > maxVectorOpWidth)
@@ -312,7 +312,7 @@
 static Value mfmaConcatIfNeeded(ConversionPatternRewriter &rewriter,
                                 Location loc, Value input) {
   Type inputType = input.getType();
-  if (auto vectorType = inputType.dyn_cast<VectorType>()) {
+  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
     if (!vectorType.getElementType().isInteger(8))
       return input;
     int64_t numBytes = vectorType.getNumElements();
@@ -342,10 +342,10 @@
   uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
            b = mfma.getBlocks();
   Type sourceElem = mfma.getSourceA().getType();
-  if (auto sourceType = sourceElem.dyn_cast<VectorType>())
+  if (auto sourceType = dyn_cast<VectorType>(sourceElem))
     sourceElem = sourceType.getElementType();
   Type destElem = mfma.getDestC().getType();
-  if (auto destType = destElem.dyn_cast<VectorType>())
+  if (auto destType = dyn_cast<VectorType>(destElem))
     destElem = destType.getElementType();

   if (sourceElem.isF32() && destElem.isF32()) {
@@ -406,7 +406,7 @@
       return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
   }

-  if (sourceElem.isa<IntegerType>() && destElem.isInteger(32)) {
+  if (isa<IntegerType>(sourceElem) && destElem.isInteger(32)) {
     if (m == 32 && n == 32 && k == 4 && b == 2)
       return ROCDL::mfma_i32_32x32x4i8::getOperationName();
     if (m == 16 && n == 16 && k == 4 && b == 4)
@@ -435,7 +435,7 @@
     // Known to be correct because there are no scalar f8 instructions and
     // because a length mismatch will have been caught by the verifier.
     Type sourceBElem =
-        mfma.getSourceB().getType().cast<VectorType>().getElementType();
+        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
     if (m == 16 && n == 16 && k == 32 && b == 1) {
       if (sourceBElem.isFloat8E5M2FNUZ())
         return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
@@ -453,7 +453,7 @@
   if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() &&
       chipset.minorVersion >= 0x40) {
     Type sourceBElem =
-        mfma.getSourceB().getType().cast<VectorType>().getElementType();
+        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
     if (m == 16 && n == 16 && k == 32 && b == 1) {
       if (sourceBElem.isFloat8E5M2FNUZ())
         return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -226,7 +226,7 @@
       Type resultType = std::get<1>(pair);
       std::optional<arith::AtomicRMWKind> reductionOp =
          arith::symbolizeAtomicRMWKind(
-              static_cast<uint64_t>(reduction.cast<IntegerAttr>().getInt()));
+              static_cast<uint64_t>(cast<IntegerAttr>(reduction).getInt()));
       assert(reductionOp && "Reduction operation cannot be of None Type");
       arith::AtomicRMWKind reductionOpValue = *reductionOp;
       identityVals.push_back(
@@ -246,7 +246,7 @@
       // For each of the reduction operations get the respective mlir::Value.
       std::optional<arith::AtomicRMWKind> reductionOp =
          arith::symbolizeAtomicRMWKind(
-              reductions[i].cast<IntegerAttr>().getInt());
+              cast<IntegerAttr>(reductions[i]).getInt());
       assert(reductionOp && "Reduction Operation cannot be of None Type");
       arith::AtomicRMWKind reductionOpValue = *reductionOp;
       rewriter.setInsertionPoint(&parOp.getBody()->back());
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -210,7 +210,7 @@
     // Handle the scalar and 1D vector cases.
     Type operandType = adaptor.getIn().getType();
-    if (!operandType.isa<LLVM::LLVMArrayType>()) {
+    if (!isa<LLVM::LLVMArrayType>(operandType)) {
       Type targetType = this->typeConverter->convertType(resultType);
       if (targetBits < sourceBits)
         rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
@@ -220,7 +220,7 @@
       return success();
     }

-    if (!resultType.isa<VectorType>())
+    if (!isa<VectorType>(resultType))
       return rewriter.notifyMatchFailure(op, "expected vector result type");

     return LLVM::detail::handleMultidimensionalVectors(
@@ -255,7 +255,7 @@
     Location loc = op.getLoc();

     // Handle the scalar and 1D vector cases.
-    if (!operandType.isa<LLVM::LLVMArrayType>()) {
+    if (!isa<LLVM::LLVMArrayType>(operandType)) {
       Type newOverflowType = typeConverter->convertType(overflowResultType);
       Type structType =
           LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType});
@@ -269,7 +269,7 @@
       return success();
     }

-    if (!sumResultType.isa<VectorType>())
+    if (!isa<VectorType>(sumResultType))
       return rewriter.notifyMatchFailure(loc, "expected vector result types");

     return rewriter.notifyMatchFailure(loc,
@@ -295,16 +295,16 @@
     // matching extended multiplication intrinsic, perform regular multiplication
     // on operands zero-extended to i(2*N) bits, and truncate the results back to
     // iN types.
-    if (!resultType.isa<LLVM::LLVMArrayType>()) {
+    if (!isa<LLVM::LLVMArrayType>(resultType)) {
       // Shift amount necessary to extract the high bits from widened result.
       TypedAttr shiftValAttr;
-      if (auto intTy = resultType.dyn_cast<IntegerType>()) {
+      if (auto intTy = dyn_cast<IntegerType>(resultType)) {
         unsigned resultBitwidth = intTy.getWidth();
         auto attrTy = rewriter.getIntegerType(resultBitwidth * 2);
         shiftValAttr = rewriter.getIntegerAttr(attrTy, resultBitwidth);
       } else {
-        auto vecTy = resultType.cast<VectorType>();
+        auto vecTy = cast<VectorType>(resultType);
         unsigned resultBitwidth = vecTy.getElementTypeBitWidth();
         auto attrTy = VectorType::get(
             vecTy.getShape(), rewriter.getIntegerType(resultBitwidth * 2));
@@ -330,7 +330,7 @@
       return success();
     }

-    if (!resultType.isa<VectorType>())
+    if (!isa<VectorType>(resultType))
       return rewriter.notifyMatchFailure(op, "expected vector result type");

     return rewriter.notifyMatchFailure(op,
@@ -355,7 +355,7 @@
     Type resultType = op.getResult().getType();

     // Handle the scalar and 1D vector cases.
-    if (!operandType.isa<LLVM::LLVMArrayType>()) {
+    if (!isa<LLVM::LLVMArrayType>(operandType)) {
       rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
           op, typeConverter->convertType(resultType),
           convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
@@ -363,7 +363,7 @@
       return success();
     }

-    if (!resultType.isa<VectorType>())
+    if (!isa<VectorType>(resultType))
       return rewriter.notifyMatchFailure(op, "expected vector result type");

     return LLVM::detail::handleMultidimensionalVectors(
@@ -389,7 +389,7 @@
     Type resultType = op.getResult().getType();

     // Handle the scalar and 1D vector cases.
-    if (!operandType.isa<LLVM::LLVMArrayType>()) {
+    if (!isa<LLVM::LLVMArrayType>(operandType)) {
       rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
           op, typeConverter->convertType(resultType),
           convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
@@ -397,7 +397,7 @@
       return success();
     }

-    if (!resultType.isa<VectorType>())
+    if (!isa<VectorType>(resultType))
       return rewriter.notifyMatchFailure(op, "expected vector result type");

     return LLVM::detail::handleMultidimensionalVectors(
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -261,9 +261,9 @@
 /// Converts the given `srcAttr` into a boolean attribute if it holds an
 /// integral value. Returns null attribute if conversion fails.
 static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
-  if (auto boolAttr = srcAttr.dyn_cast<BoolAttr>())
+  if (auto boolAttr = dyn_cast<BoolAttr>(srcAttr))
     return boolAttr;
-  if (auto intAttr = srcAttr.dyn_cast<IntegerAttr>())
+  if (auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
     return builder.getBoolAttr(intAttr.getValue().getBoolValue());
   return {};
 }
@@ -324,7 +324,7 @@
   if (type.isInteger(1))
     return true;

-  if (auto vecType = type.dyn_cast<VectorType>())
+  if (auto vecType = dyn_cast<VectorType>(type))
     return vecType.getElementType().isInteger(1);

   return false;
@@ -337,7 +337,7 @@
     unsigned bw = 0;
     if (type.isIntOrFloat())
       bw = type.getIntOrFloatBitWidth();
-    else if (auto vecType = type.dyn_cast<VectorType>())
+    else if (auto vecType = dyn_cast<VectorType>(type))
       bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
     return bw;
   };
@@ -369,18 +369,18 @@
 LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
     arith::ConstantOp constOp, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
-  auto srcType = constOp.getType().dyn_cast<ShapedType>();
+  auto srcType = dyn_cast<ShapedType>(constOp.getType());
   if (!srcType || srcType.getNumElements() == 1)
     return failure();

   // arith.constant should only have vector or tenor types.
-  assert((srcType.isa<VectorType, RankedTensorType>()));
+  assert((isa<VectorType, RankedTensorType>(srcType)));

   Type dstType = getTypeConverter()->convertType(srcType);
   if (!dstType)
     return failure();

-  auto dstElementsAttr = constOp.getValue().dyn_cast<DenseElementsAttr>();
+  auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
   if (!dstElementsAttr)
     return failure();
@@ -388,7 +388,7 @@
   // If the composite type has more than one dimensions, perform linearization.
   if (srcType.getRank() > 1) {
-    if (srcType.isa<RankedTensorType>()) {
+    if (isa<RankedTensorType>(srcType)) {
       dstAttrType = RankedTensorType::get(srcType.getNumElements(),
                                           srcType.getElementType());
       dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
@@ -402,19 +402,19 @@
   Type dstElemType;
   // Tensor types are converted to SPIR-V array types; vector types are
   // converted to SPIR-V vector/array types.
-  if (auto arrayType = dstType.dyn_cast<spirv::ArrayType>())
+  if (auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
     dstElemType = arrayType.getElementType();
   else
-    dstElemType = dstType.cast<VectorType>().getElementType();
+    dstElemType = cast<VectorType>(dstType).getElementType();

   // If the source and destination element types are different, perform
   // attribute conversion.
   if (srcElemType != dstElemType) {
     SmallVector<Attribute, 8> elements;
-    if (srcElemType.isa<FloatType>()) {
+    if (isa<FloatType>(srcElemType)) {
       for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
         FloatAttr dstAttr =
-            convertFloatAttr(srcAttr, dstElemType.cast<FloatType>(), rewriter);
+            convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter);
         if (!dstAttr)
           return failure();
         elements.push_back(dstAttr);
@@ -424,7 +424,7 @@
     } else {
       for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
         IntegerAttr dstAttr = convertIntegerAttr(
-            srcAttr, dstElemType.cast<IntegerType>(), rewriter);
+            srcAttr, cast<IntegerType>(dstElemType), rewriter);
         if (!dstAttr)
           return failure();
         elements.push_back(dstAttr);
@@ -435,7 +435,7 @@
     // attributes; element attributes only works with builtin types. So we need
     // to prepare another converted builtin types for the destination elements
     // attribute.
-    if (dstAttrType.isa<RankedTensorType>())
+    if (isa<RankedTensorType>(dstAttrType))
       dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType);
     else
       dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
@@ -456,7 +456,7 @@
     arith::ConstantOp constOp, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   Type srcType = constOp.getType();
-  if (auto shapedType = srcType.dyn_cast<ShapedType>()) {
+  if (auto shapedType = dyn_cast<ShapedType>(srcType)) {
     if (shapedType.getNumElements() != 1)
       return failure();
     srcType = shapedType.getElementType();
@@ -465,7 +465,7 @@
     return failure();

   Attribute cstAttr = constOp.getValue();
-  if (auto elementsAttr = cstAttr.dyn_cast<DenseElementsAttr>())
+  if (auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
     cstAttr = elementsAttr.getSplatValue<Attribute>();

   Type dstType = getTypeConverter()->convertType(srcType);
@@ -473,14 +473,14 @@
     return failure();

   // Floating-point types.
-  if (srcType.isa<FloatType>()) {
-    auto srcAttr = cstAttr.cast<FloatAttr>();
+  if (isa<FloatType>(srcType)) {
+    auto srcAttr = cast<FloatAttr>(cstAttr);
     auto dstAttr = srcAttr;

     // Floating-point types not supported in the target environment are all
     // converted to float type.
     if (srcType != dstType) {
-      dstAttr = convertFloatAttr(srcAttr, dstType.cast<FloatType>(), rewriter);
+      dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
       if (!dstAttr)
         return failure();
     }
@@ -502,9 +502,9 @@
   // IndexType or IntegerType. Index values are converted to 32-bit integer
   // values when converting to SPIR-V.
-  auto srcAttr = cstAttr.cast<IntegerAttr>();
+  auto srcAttr = cast<IntegerAttr>(cstAttr);
   IntegerAttr dstAttr =
-      convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter);
+      convertIntegerAttr(srcAttr, cast<IntegerType>(dstType), rewriter);
   if (!dstAttr)
     return failure();
   rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
@@ -678,12 +678,12 @@
       return getTypeConversionFailure(rewriter, op);

     Value allOnes;
-    if (auto intTy = dstType.dyn_cast<IntegerType>()) {
+    if (auto intTy = dyn_cast<IntegerType>(dstType)) {
       unsigned componentBitwidth = intTy.getWidth();
       allOnes = rewriter.create<spirv::ConstantOp>(
           loc, intTy,
           rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
-    } else if (auto vectorTy = dstType.dyn_cast<VectorType>()) {
+    } else if (auto vectorTy = dyn_cast<VectorType>(dstType)) {
       unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
       allOnes = rewriter.create<spirv::ConstantOp>(
           loc, vectorTy,
@@ -810,7 +810,7 @@
     // There are no direct corresponding instructions in SPIR-V for such cases.
     // Extend them to 32-bit and do comparision then.
     Type type = rewriter.getI32Type();
-    if (auto vectorType = dstType.dyn_cast<VectorType>())
+    if (auto vectorType = dyn_cast<VectorType>(dstType))
       type = VectorType::get(vectorType.getShape(), type);
     Value extLhs =
         rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
--- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
+++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
@@ -33,8 +33,8 @@
   /// arm.neon.intr.sdot
   LogicalResult matchAndRewrite(Sdot2dOp op,
                                 PatternRewriter &rewriter) const override {
-    Type elemType = op.getB().getType().cast<VectorType>().getElementType();
-    int length = op.getB().getType().cast<VectorType>().getShape()[0] *
+    Type elemType = cast<VectorType>(op.getB().getType()).getElementType();
+    int length = cast<VectorType>(op.getB().getType()).getShape()[0] *
                  Sdot2dOp::kReductionSize;
     VectorType flattenedVectorType = VectorType::get({length}, elemType);
     Value b2d = op.getB();
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -366,12 +366,12 @@
 static std::optional<Type> convertAsyncTypes(Type type,
                                              bool useOpaquePointers) {
-  if (type.isa<TokenType, GroupType, ValueType>())
+  if (isa<TokenType, GroupType, ValueType>(type))
     return AsyncAPI::opaquePointerType(type.getContext(), useOpaquePointers);

-  if (type.isa<CoroIdType, CoroStateType>())
+  if (isa<CoroIdType, CoroStateType>(type))
     return AsyncAPI::tokenType(type.getContext());
-  if (type.isa<CoroHandleType>())
+  if (isa<CoroHandleType>(type))
     return AsyncAPI::opaquePointerType(type.getContext(), useOpaquePointers);

   return std::nullopt;
@@ -656,14 +656,14 @@
     Type resultType = op->getResultTypes()[0];

     // Tokens creation maps to a simple function call.
-    if (resultType.isa<TokenType>()) {
+    if (isa<TokenType>(resultType)) {
       rewriter.replaceOpWithNewOp<func::CallOp>(
           op, kCreateToken, converter->convertType(resultType));
       return success();
     }

     // To create a value we need to compute the storage requirement.
-    if (auto value = resultType.dyn_cast<ValueType>()) {
+    if (auto value = dyn_cast<ValueType>(resultType)) {
       // Returns the size requirements for the async value storage.
       auto sizeOf = [&](ValueType valueType) -> Value {
         auto loc = op->getLoc();
@@ -994,7 +994,7 @@
   matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Currently we can only add tokens to the group.
-    if (!op.getOperand().getType().isa<TokenType>())
+    if (!isa<TokenType>(op.getOperand().getType()))
       return rewriter.notifyMatchFailure(op, "only token type is supported");

     // Replace with a runtime API function call.
diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
--- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
+++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
@@ -41,11 +41,11 @@
                   ConversionPatternRewriter &rewriter) const override {
     // Check for unranked memref types which are currently not supported.
     Type type = op.getType();
-    if (type.isa<UnrankedMemRefType>()) {
+    if (isa<UnrankedMemRefType>(type)) {
       return rewriter.notifyMatchFailure(
           op, "UnrankedMemRefType is not supported.");
     }
-    MemRefType memrefType = type.cast<MemRefType>();
+    MemRefType memrefType = cast<MemRefType>(type);
     MemRefLayoutAttrInterface layout;
     auto allocType =
         MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
diff --git a/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp b/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp
--- a/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp
+++ b/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp
@@ -26,9 +26,9 @@
 // result type.
 struct ComplexTypeResolver {
   std::optional<bool> operator()(Type type) const {
-    auto complexType = type.cast<ComplexType>();
+    auto complexType = cast<ComplexType>(type);
     auto elementType = complexType.getElementType();
-    if (!elementType.isa<Float32Type, Float64Type>())
+    if (!isa<Float32Type, Float64Type>(elementType))
       return {};

     return elementType.getIntOrFloatBitWidth() == 64;
@@ -39,8 +39,8 @@
 // type.
 struct FloatTypeResolver {
   std::optional<bool> operator()(Type type) const {
-    auto elementType = type.cast<FloatType>();
-    if (!elementType.isa<Float32Type, Float64Type>())
+    auto elementType = cast<FloatType>(type);
+    if (!isa<Float32Type, Float64Type>(elementType))
       return {};

     return elementType.getIntOrFloatBitWidth() == 64;
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -57,7 +57,7 @@
                   ConversionPatternRewriter &rewriter) const override {
     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

-    auto type = op.getType().cast<ComplexType>();
+    auto type = cast<ComplexType>(op.getType());
     Type elementType = type.getElementType();
     Value lhs = adaptor.getLhs();
@@ -102,10 +102,7 @@
   matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
-    auto type = adaptor.getLhs()
-                    .getType()
-                    .template cast<ComplexType>()
-                    .getElementType();
+    auto type = cast<ComplexType>(adaptor.getLhs().getType()).getElementType();

     Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs());
     Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs());
@@ -132,8 +129,8 @@
   LogicalResult
   matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto type = adaptor.getLhs().getType().template cast<ComplexType>();
-    auto elementType = type.getElementType().template cast<FloatType>();
+    auto type = cast<ComplexType>(adaptor.getLhs().getType());
+    auto elementType = cast<FloatType>(type.getElementType());
     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

     Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs());
@@ -160,8 +157,8 @@
   matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
-    auto type = adaptor.getComplex().getType().template cast<ComplexType>();
-    auto elementType = type.getElementType().template cast<FloatType>();
+    auto type = cast<ComplexType>(adaptor.getComplex().getType());
+    auto elementType = cast<FloatType>(type.getElementType());

     Value real =
         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
@@ -222,8 +219,8 @@
   matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
-    auto type = adaptor.getLhs().getType().cast<ComplexType>();
-    auto elementType = type.getElementType().cast<FloatType>();
+    auto type = cast<ComplexType>(adaptor.getLhs().getType());
+    auto elementType = cast<FloatType>(type.getElementType());

     Value lhsReal =
         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs());
@@ -441,8 +438,8 @@
   matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
-    auto type = adaptor.getComplex().getType().cast<ComplexType>();
-    auto elementType = type.getElementType().cast<FloatType>();
+    auto type = cast<ComplexType>(adaptor.getComplex().getType());
+    auto elementType = cast<FloatType>(type.getElementType());

     Value real =
         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
@@ -466,8 +463,8 @@
   LogicalResult
   matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto type = adaptor.getComplex().getType().cast<ComplexType>();
-    auto elementType = type.getElementType().cast<FloatType>();
+    auto type = cast<ComplexType>(adaptor.getComplex().getType());
+    auto elementType = cast<FloatType>(type.getElementType());

     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
     Value exp = b.create<complex::ExpOp>(adaptor.getComplex());
@@ -490,8 +487,8 @@
   LogicalResult
   matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto type = adaptor.getComplex().getType().cast<ComplexType>();
-    auto elementType = type.getElementType().cast<FloatType>();
+    auto type = cast<ComplexType>(adaptor.getComplex().getType());
+    auto elementType = cast<FloatType>(type.getElementType());
     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

     Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex());
@@ -511,8 +508,8 @@
   LogicalResult
   matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto type = adaptor.getComplex().getType().cast<ComplexType>();
-    auto elementType = type.getElementType().cast<FloatType>();
+    auto type = cast<ComplexType>(adaptor.getComplex().getType());
+    auto elementType = cast<FloatType>(type.getElementType());
     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
@@ -550,8 +547,8 @@
   matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    auto type = adaptor.getLhs().getType().cast<ComplexType>();
-    auto elementType = type.getElementType().cast<FloatType>();
+    auto type = cast<ComplexType>(adaptor.getLhs().getType());
+    auto elementType = cast<FloatType>(type.getElementType());

     Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs());
     Value lhsRealAbs = b.create<math::AbsFOp>(lhsReal);
@@ -727,8 +724,8 @@
   matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
-    auto type = adaptor.getComplex().getType().cast<ComplexType>();
-    auto elementType = type.getElementType().cast<FloatType>();
+    auto type = cast<ComplexType>(adaptor.getComplex().getType());
+    auto elementType = cast<FloatType>(type.getElementType());

     Value real =
         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
@@ -773,7 +770,7 @@
                   ConversionPatternRewriter &rewriter) const override {
     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

-    auto type = op.getType().cast<ComplexType>();
+    auto type = cast<ComplexType>(op.getType());
     Type elementType = type.getElementType();
     Value arg = adaptor.getComplex();
@@ -837,8 +834,8 @@
   LogicalResult
   matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto type = adaptor.getComplex().getType().cast<ComplexType>();
-    auto elementType = type.getElementType().cast<FloatType>();
+    auto type = cast<ComplexType>(adaptor.getComplex().getType());
+    auto elementType = cast<FloatType>(type.getElementType());
     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
@@ -881,8 +878,8 @@
   matchAndRewrite(complex::TanhOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
-    auto type = adaptor.getComplex().getType().cast<ComplexType>();
-    auto elementType = type.getElementType().cast<FloatType>();
+    auto type = cast<ComplexType>(adaptor.getComplex().getType());
+    auto elementType = cast<FloatType>(type.getElementType());

     // The hyperbolic tangent for complex number can be calculated as follows.
     // tanh(x + i * y) = (tanh(x) + i * tan(y)) / (1 + tanh(x) * tan(y))
@@ -913,8 +910,8 @@
   matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
-    auto type = adaptor.getComplex().getType().cast<ComplexType>();
-    auto elementType = type.getElementType().cast<FloatType>();
+    auto type = cast<ComplexType>(adaptor.getComplex().getType());
+    auto elementType = cast<FloatType>(type.getElementType());
     Value real =
         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
     Value imag =
@@ -933,7 +930,7 @@
 static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
                                  ComplexType type, Value a, Value b, Value c,
                                  Value d) {
-  auto elementType = type.getElementType().cast<FloatType>();
+  auto elementType = cast<FloatType>(type.getElementType());

   // Compute (a*a+b*b)^(0.5c).
   Value aaPbb = builder.create<arith::AddFOp>(
@@ -995,8 +992,8 @@
   matchAndRewrite(complex::PowOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
-    auto type = adaptor.getLhs().getType().cast<ComplexType>();
-    auto elementType = type.getElementType().cast<FloatType>();
+    auto type = cast<ComplexType>(adaptor.getLhs().getType());
+    auto elementType = cast<FloatType>(type.getElementType());

     Value a = builder.create<complex::ReOp>(elementType, adaptor.getLhs());
     Value b = builder.create<complex::ImOp>(elementType, adaptor.getLhs());
@@ -1015,8 +1012,8 @@
   matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
-    auto type = adaptor.getComplex().getType().cast<ComplexType>();
-    auto elementType = type.getElementType().cast<FloatType>();
+    auto type = cast<ComplexType>(adaptor.getComplex().getType());
+    auto elementType = cast<FloatType>(type.getElementType());

     Value a = builder.create<complex::ReOp>(elementType, adaptor.getComplex());
     Value b = builder.create<complex::ImOp>(elementType, adaptor.getComplex());
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -144,13 +144,13 @@
   size_t argOffset = resultStructType ? 1 : 0;
   for (auto [index, argType] : llvm::enumerate(type.getInputs())) {
     Value arg = wrapperFuncOp.getArgument(index + argOffset);
-    if (auto memrefType = argType.dyn_cast<MemRefType>()) {
+    if (auto memrefType = dyn_cast<MemRefType>(argType)) {
       Value loaded = rewriter.create<LLVM::LoadOp>(
           loc, typeConverter.convertType(memrefType), arg);
       MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
       continue;
     }
-    if (argType.isa<UnrankedMemRefType>()) {
+    if (isa<UnrankedMemRefType>(argType)) {
       Value loaded = rewriter.create<LLVM::LoadOp>(
           loc, typeConverter.convertType(argType), arg);
       UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args);
@@ -218,8 +218,7 @@
   if (resultStructType) {
     // Allocate the struct on the stack and pass the pointer.
-    Type resultType =
-        wrapperType.cast<LLVM::LLVMFunctionType>().getParamType(0);
+    Type resultType = cast<LLVM::LLVMFunctionType>(wrapperType).getParamType(0);
     Value one = builder.create<LLVM::ConstantOp>(
         loc, typeConverter.convertType(builder.getIndexType()),
         builder.getIntegerAttr(builder.getIndexType(), 1));
@@ -233,8 +232,8 @@
   for (Type input : type.getInputs()) {
     Value arg;
     int numToDrop = 1;
-    auto memRefType = input.dyn_cast<MemRefType>();
-    auto unrankedMemRefType = input.dyn_cast<UnrankedMemRefType>();
+    auto memRefType = dyn_cast<MemRefType>(input);
+    auto unrankedMemRefType = dyn_cast<UnrankedMemRefType>(input);
     if (memRefType || unrankedMemRefType) {
       numToDrop = memRefType
                       ? MemRefDescriptor::getNumUnpackedValues(memRefType)
@@ -301,9 +300,9 @@
       // Unranked memrefs are not supported in the bare pointer calling
       // convention. We should have bailed out before in the presence of
      // unranked memrefs.
-      assert(!argTy.isa<UnrankedMemRefType>() &&
+      assert(!isa<UnrankedMemRefType>(argTy) &&
              "Unranked memref is not supported");
-      auto memrefTy = argTy.dyn_cast<MemRefType>();
+      auto memrefTy = dyn_cast<MemRefType>(argTy);
       if (!memrefTy)
         continue;
@@ -360,18 +359,18 @@
   }
   if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
     SmallVector<Attribute, 4> newArgAttrs(
-        llvmType.cast<LLVM::LLVMFunctionType>().getNumParams());
+        cast<LLVM::LLVMFunctionType>(llvmType).getNumParams());
     for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
       // Some LLVM IR attribute have a type attached to them. During FuncOp ->
       // LLVMFuncOp conversion these types may have changed. Account for that
       // change by converting attributes' types as well.
       SmallVector<NamedAttribute> convertedAttrs;
-      auto attrsDict = argAttrDicts[i].cast<DictionaryAttr>();
+      auto attrsDict = cast<DictionaryAttr>(argAttrDicts[i]);
       convertedAttrs.reserve(attrsDict.size());
       for (const NamedAttribute &attr : attrsDict) {
         const auto convert = [&](const NamedAttribute &attr) {
           return TypeAttr::get(getTypeConverter()->convertType(
-              attr.getValue().cast<TypeAttr>().getValue()));
+              cast<TypeAttr>(attr.getValue()).getValue()));
         };
         if (attr.getName().getValue() ==
             LLVM::LLVMDialect::getByValAttrName()) {
@@ -418,7 +417,7 @@
   LLVM::Linkage linkage = LLVM::Linkage::External;
   if (funcOp->hasAttr(linkageAttrName)) {
     auto attr =
-        funcOp->getAttr(linkageAttrName).dyn_cast<LLVM::LinkageAttr>();
+        dyn_cast<LLVM::LinkageAttr>(funcOp->getAttr(linkageAttrName));
     if (!attr) {
       funcOp->emitError() << "Contains " << linkageAttrName
                           << " attribute not of type LLVM::LinkageAttr";
@@ -545,7 +544,7 @@
     if (useBarePtrCallConv) {
       for (auto it : callOp->getOperands()) {
         Type operandType = it.getType();
-        if (operandType.isa<UnrankedMemRefType>()) {
+        if (isa<UnrankedMemRefType>(operandType)) {
           // Unranked memref is not supported in the bare pointer calling
           // convention.
           return failure();
@@ -669,11 +668,11 @@
     for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
       Type oldTy = std::get<0>(it).getType();
       Value newOperand = std::get<1>(it);
-      if (oldTy.isa<MemRefType>() && getTypeConverter()->canConvertToBarePtr(
-                                         oldTy.cast<BaseMemRefType>())) {
+      if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
+                                        cast<BaseMemRefType>(oldTy))) {
         MemRefDescriptor memrefDesc(newOperand);
         newOperand = memrefDesc.allocatedPtr(rewriter, loc);
-      } else if (oldTy.isa<UnrankedMemRefType>()) {
+      } else if (isa<UnrankedMemRefType>(oldTy)) {
         // Unranked memref is not supported in the bare pointer calling
         // convention.
         return failure();
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -26,22 +26,20 @@
   for (const auto &en :
        llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
     BlockArgument attribution = en.value();

-    auto type = attribution.getType().dyn_cast<MemRefType>();
+    auto type = dyn_cast<MemRefType>(attribution.getType());
     assert(type && type.hasStaticShape() && "unexpected type in attribution");

     uint64_t numElements = type.getNumElements();

     auto elementType =
-        typeConverter->convertType(type.getElementType()).template cast<Type>();
+        cast<Type>(typeConverter->convertType(type.getElementType()));
     auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
     std::string name = std::string(
         llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index()));
     uint64_t alignment = 0;
     if (auto alignAttr =
-            gpuFuncOp
-                .getWorkgroupAttributionAttr(
-                    en.index(), LLVM::LLVMDialect::getAlignAttrName())
-                .dyn_cast_or_null<IntegerAttr>())
+            dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getWorkgroupAttributionAttr(
+                en.index(), LLVM::LLVMDialect::getAlignAttrName())))
       alignment = alignAttr.getInt();
     auto globalOp = rewriter.create<LLVM::GlobalOp>(
         gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
@@ -100,7 +98,7 @@
                                               global.getAddrSpace()),
           global.getSymNameAttr());
       auto elementType =
-          global.getType().cast<LLVM::LLVMArrayType>().getElementType();
+          cast<LLVM::LLVMArrayType>(global.getType()).getElementType();
       Value memory = rewriter.create<LLVM::GEPOp>(
           loc,
           getTypeConverter()->getPointerType(elementType,
@@ -112,7 +110,7 @@
       // otherwise necessary given that memref sizes are fixed, but we can try
       // and canonicalize that away later.
       Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()];
-      auto type = attribution.getType().cast<MemRefType>();
+      auto type = cast<MemRefType>(attribution.getType());
       auto descr = MemRefDescriptor::fromStaticShape(
           rewriter, loc, *getTypeConverter(), type, memory);
       signatureConversion.remapInput(numProperArguments + en.index(), descr);
@@ -123,7 +121,7 @@
     auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
     for (const auto &en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
       Value attribution = en.value();
-      auto type = attribution.getType().cast<MemRefType>();
+      auto type = cast<MemRefType>(attribution.getType());
       assert(type && type.hasStaticShape() && "unexpected type in attribution");

       // Explicitly drop memory space when lowering private memory
@@ -136,10 +134,8 @@
           gpuFuncOp.getLoc(), int64Ty, type.getNumElements());
       uint64_t alignment = 0;
       if (auto alignAttr =
-              gpuFuncOp
-                  .getPrivateAttributionAttr(
-                      en.index(), LLVM::LLVMDialect::getAlignAttrName())
-                  .dyn_cast_or_null<IntegerAttr>())
+              dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getPrivateAttributionAttr(
+                  en.index(), LLVM::LLVMDialect::getAlignAttrName())))
         alignment = alignAttr.getInt();
       Value allocated = rewriter.create<LLVM::AllocaOp>(
           gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment);
@@ -164,7 +160,7 @@
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPointToStart(&llvmFuncOp.getBody().front());
     for (const auto &en : llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
-      auto memrefTy = en.value().dyn_cast<MemRefType>();
+      auto memrefTy = dyn_cast<MemRefType>(en.value());
       if (!memrefTy)
         continue;
       assert(memrefTy.hasStaticShape() &&
@@ -302,7 +298,7 @@
         rewriter.create<LLVM::ConstantOp>(loc, llvmI32, numArgsThisCall));
     for (size_t i = group; i < bound; ++i) {
       Value arg = adaptor.getArgs()[i];
-      if (auto floatType = arg.getType().dyn_cast<FloatType>()) {
+      if (auto floatType = dyn_cast<FloatType>(arg.getType())) {
         if (!floatType.isF64())
           arg = rewriter.create<LLVM::FPExtOp>(
               loc, typeConverter->convertType(rewriter.getF64Type()), arg);
@@ -428,7 +424,7 @@
       Type type = arg.getType();
       Value promotedArg = arg;
       assert(type.isIntOrFloat());
-      if (type.isa<FloatType>()) {
+      if (isa<FloatType>(type)) {
         type = rewriter.getF64Type();
         promotedArg = rewriter.create<LLVM::FPExtOp>(loc, type, arg);
       }
@@ -462,14 +458,14 @@
                                    LLVMTypeConverter &converter) {
   TypeRange operandTypes(operands);
   if (llvm::none_of(operandTypes,
-                    [](Type type) { return type.isa<VectorType>(); })) {
+                    [](Type type) { return isa<VectorType>(type); })) {
     return rewriter.notifyMatchFailure(op, "expected vector operand");
   }
   if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0)
     return rewriter.notifyMatchFailure(op, "expected no region/successor");
   if (op->getNumResults() != 1)
     return rewriter.notifyMatchFailure(op, "expected single result");
-  VectorType vectorType = op->getResult(0).getType().dyn_cast<VectorType>();
+  VectorType vectorType = dyn_cast<VectorType>(op->getResult(0).getType());
   if (!vectorType)
     return rewriter.notifyMatchFailure(op, "expected vector result");

@@ -482,7 +478,7 @@
   for (int64_t i = 0; i < vectorType.getNumElements(); ++i) {
     Value index = rewriter.create<LLVM::ConstantOp>(loc, indexType, i);
     auto extractElement = [&](Value operand) -> Value {
-      if (!operand.getType().isa<VectorType>())
+      if (!isa<VectorType>(operand.getType()))
         return operand;
       return rewriter.create<LLVM::ExtractElementOp>(loc, operand, index);
     };
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -454,7 +454,7 @@
   Location loc = op->getLoc();

   auto memRefType = hostRegisterOp.getValue().getType();
-  auto elementType = memRefType.cast<UnrankedMemRefType>().getElementType();
+  auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
   auto elementSize = getSizeInBytes(loc, elementType, rewriter);

   auto arguments = getTypeConverter()->promoteOperands(
@@ -476,7 +476,7 @@
   Location loc = op->getLoc();

   auto memRefType = hostUnregisterOp.getValue().getType();
-  auto elementType = memRefType.cast<UnrankedMemRefType>().getElementType();
+  auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
   auto elementSize = getSizeInBytes(loc, elementType, rewriter);

   auto arguments = getTypeConverter()->promoteOperands(
@@ -555,7 +555,7 @@
 }

 static bool isGpuAsyncTokenType(Value value) {
-  return value.getType().isa<gpu::AsyncTokenType>();
+  return isa<gpu::AsyncTokenType>(value.getType());
 }

 // Converts !gpu.async.token operands of `async.yield` to runtime calls. The
@@ -591,7 +591,7 @@
 // Returns whether `value` is the result of an LLVM::CallOp to `functionName`.
static bool isDefinedByCallTo(Value value, StringRef functionName) { - assert(value.getType().isa()); + assert(isa(value.getType())); if (auto defOp = value.getDefiningOp()) return defOp.getCallee()->equals(functionName); return false; @@ -862,7 +862,7 @@ LLVM::LLVMPointerType destinationType, Value sourcePtr, LLVMTypeConverter &typeConverter) { - auto sourceTy = sourcePtr.getType().cast(); + auto sourceTy = cast(sourcePtr.getType()); if (destinationType.getAddressSpace() != sourceTy.getAddressSpace()) sourcePtr = rewriter.create( loc, @@ -879,7 +879,7 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite( gpu::MemcpyOp memcpyOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - auto memRefType = memcpyOp.getSrc().getType().cast(); + auto memRefType = cast(memcpyOp.getSrc().getType()); if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) || !isConvertibleAndHasIdentityMaps(memRefType) || @@ -919,7 +919,7 @@ LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite( gpu::MemsetOp memsetOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - auto memRefType = memsetOp.getDst().getType().cast(); + auto memRefType = cast(memsetOp.getDst().getType()); if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) || !isConvertibleAndHasIdentityMaps(memRefType) || diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp --- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp @@ -91,7 +91,7 @@ ? NVVM::MMALayout::col : NVVM::MMALayout::row; gpu::MMAMatrixType retType = - subgroupMmaLoadMatrixOp.getRes().getType().cast(); + cast(subgroupMmaLoadMatrixOp.getRes().getType()); ArrayRef retTypeShape = retType.getShape(); int64_t m = 0; int64_t n = 0; @@ -122,8 +122,7 @@ // Create nvvm.mma_load op according to the operand types. Value dataPtr = getStridedElementPtr( - loc, - subgroupMmaLoadMatrixOp.getSrcMemref().getType().cast(), + loc, cast(subgroupMmaLoadMatrixOp.getSrcMemref().getType()), adaptor.getSrcMemref(), adaptor.getIndices(), rewriter); Value leadingDim = rewriter.create( @@ -158,7 +157,7 @@ // Get the shape of the MMAMatrix type being stored. The shape will // choose which intrinsic this op will be lowered to. gpu::MMAMatrixType srcType = - subgroupMmaStoreMatrixOp.getSrc().getType().cast(); + cast(subgroupMmaStoreMatrixOp.getSrc().getType()); ArrayRef srcTypeShape = srcType.getShape(); NVVM::MMALayout layout = subgroupMmaStoreMatrixOp.getTranspose() ? 
NVVM::MMALayout::col @@ -170,7 +169,7 @@ if (NVVM::WMMAStoreOp::getIntrinsicID(m, n, k, layout, eltype) == 0) return rewriter.notifyMatchFailure(op, kInvalidCaseStr); - auto matrixType = adaptor.getSrc().getType().cast(); + auto matrixType = cast(adaptor.getSrc().getType()); for (unsigned i = 0, e = matrixType.getBody().size(); i < e; ++i) { Value toUse = rewriter.create(loc, adaptor.getSrc(), i); @@ -179,7 +178,7 @@ Value dataPtr = getStridedElementPtr( loc, - subgroupMmaStoreMatrixOp.getDstMemref().getType().cast(), + cast(subgroupMmaStoreMatrixOp.getDstMemref().getType()), adaptor.getDstMemref(), adaptor.getIndices(), rewriter); Value leadingDim = rewriter.create( loc, rewriter.getI32Type(), @@ -214,7 +213,7 @@ SmallVector unpackedOps; auto unpackOp = [&](Value operand) { - auto structType = operand.getType().cast(); + auto structType = cast(operand.getType()); for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) { Value toUse = rewriter.create(loc, operand, i); unpackedOps.push_back(toUse); @@ -224,10 +223,10 @@ // Get the shapes of the MMAMatrix type being used. The shapes will // choose which intrinsic this op will be lowered to. gpu::MMAMatrixType aType = - subgroupMmaComputeOp.getOpA().getType().cast(); + cast(subgroupMmaComputeOp.getOpA().getType()); ArrayRef aTypeShape = aType.getShape(); gpu::MMAMatrixType cType = - subgroupMmaComputeOp.getOpC().getType().cast(); + cast(subgroupMmaComputeOp.getOpC().getType()); ArrayRef cTypeShape = cType.getShape(); int64_t m = cTypeShape[0]; int64_t n = cTypeShape[1]; @@ -245,7 +244,7 @@ return rewriter.notifyMatchFailure(op, kInvalidCaseStr); NVVM::MMATypes bElementType = getElementType( - subgroupMmaComputeOp.getOpB().getType().cast()); + cast(subgroupMmaComputeOp.getOpB().getType())); if (bElementType != sourceType) return rewriter.notifyMatchFailure( op, "WMMA compute op input matrix element types must match."); @@ -277,9 +276,9 @@ Location loc = subgroupMmaConstantOp.getLoc(); Value cst = adaptor.getOperands()[0]; LLVM::LLVMStructType type = convertMMAToLLVMType( - subgroupMmaConstantOp.getType().cast()); + cast(subgroupMmaConstantOp.getType())); // If the element type is a vector create a vector from the operand. - if (auto vecType = type.getBody()[0].dyn_cast()) { + if (auto vecType = dyn_cast(type.getBody()[0])) { Value vecCst = rewriter.create(loc, vecType); for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) { Value idx = rewriter.create( @@ -301,9 +300,9 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs, Value rhs, bool isMin) { - auto floatType = getElementTypeOrSelf(lhs.getType()).cast(); + auto floatType = cast(getElementTypeOrSelf(lhs.getType())); Type i1Type = builder.getI1Type(); - if (auto vecType = lhs.getType().dyn_cast()) + if (auto vecType = dyn_cast(lhs.getType())) i1Type = VectorType::get(vecType.getShape(), i1Type); Value cmp = builder.create( loc, i1Type, isMin ? 
LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt, @@ -355,7 +354,7 @@ Location loc = subgroupMmaElementwiseOp.getLoc(); size_t numOperands = adaptor.getOperands().size(); LLVM::LLVMStructType destType = convertMMAToLLVMType( - subgroupMmaElementwiseOp.getType().cast()); + cast(subgroupMmaElementwiseOp.getType())); Value matrixStruct = rewriter.create(loc, destType); for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) { SmallVector extractedOperands; diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -54,7 +54,7 @@ static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) { bool canBeBare = true; for (Type type : func.getArgumentTypes()) - if (auto memrefTy = type.dyn_cast()) + if (auto memrefTy = dyn_cast(type)) canBeBare &= LLVMTypeConverter::canConvertToBarePtr(memrefTy); return canBeBare; } @@ -166,9 +166,8 @@ // Manually rewrite known block size attributes so the LLVMIR translation // infrastructure can pick them up. m.walk([ctx](LLVM::LLVMFuncOp op) { - if (auto blockSizes = - op->removeAttr(gpu::GPUFuncOp::getKnownBlockSizeAttrName()) - .dyn_cast_or_null()) { + if (auto blockSizes = dyn_cast_or_null( + op->removeAttr(gpu::GPUFuncOp::getKnownBlockSizeAttrName()))) { op->setAttr(ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName(), blockSizes); // Also set up the rocdl.flat_work_group_size attribute to prevent diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -495,9 +495,9 @@ Type type = arg.getType(); using MembptrT = FuncT OpHandler::*; MembptrT handlerPtr; - if (type.isa()) { + if (isa(type)) { handlerPtr = &OpHandler::floatFunc; - } else if (type.isa()) { + } else if (isa(type)) { handlerPtr = &OpHandler::intFunc; } else { return std::nullopt; diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp @@ -81,9 +81,9 @@ ConversionPatternRewriter &rewriter) const override { Location loc = subgroupMmaLoadMatrixOp->getLoc(); gpu::MMAMatrixType retType = - subgroupMmaLoadMatrixOp.getRes().getType().cast(); + cast(subgroupMmaLoadMatrixOp.getRes().getType()); auto memrefType = - subgroupMmaLoadMatrixOp.getSrcMemref().getType().cast(); + cast(subgroupMmaLoadMatrixOp.getSrcMemref().getType()); Value bufferPtr = spirv::getElementPtr( *getTypeConverter(), memrefType, adaptor.getSrcMemref(), adaptor.getIndices(), loc, rewriter); @@ -114,7 +114,7 @@ ConversionPatternRewriter &rewriter) const override { Location loc = subgroupMmaStoreMatrixOp->getLoc(); auto memrefType = - subgroupMmaStoreMatrixOp.getDstMemref().getType().cast(); + cast(subgroupMmaStoreMatrixOp.getDstMemref().getType()); Value bufferPtr = spirv::getElementPtr( *getTypeConverter(), memrefType, adaptor.getDstMemref(), adaptor.getIndices(), loc, rewriter); @@ -161,7 +161,7 @@ ConversionPatternRewriter &rewriter) const override { Value cst = adaptor.getOperands()[0]; auto coopType = convertMMAToSPIRVType( - subgroupMmaConstantMatrixOp.getType().cast()); + cast(subgroupMmaConstantMatrixOp.getType())); rewriter.replaceOpWithNewOp( subgroupMmaConstantMatrixOp, coopType, cst); return 
success(); @@ -180,11 +180,11 @@ ConversionPatternRewriter &rewriter) const override { // All operands should be of cooperative matrix types. for (Value operand : adaptor.getOperands()) { - if (!operand.getType().isa()) + if (!isa(operand.getType())) return failure(); } auto coopType = convertMMAToSPIRVType( - elementwiseOp.getType().cast()); + cast(elementwiseOp.getType())); return success(createElementwiseOp(rewriter, elementwiseOp, coopType, adaptor.getOperands())); } @@ -204,7 +204,7 @@ return failure(); // All operands should be of cooperative matrix types. for (Value operand : adaptor.getOperands()) { - if (!operand.getType().isa()) + if (!isa(operand.getType())) return failure(); } @@ -236,7 +236,7 @@ scalar = cc.getConstituents().front(); auto coopType = convertMMAToSPIRVType( - elementwiseOp.getType().cast()); + cast(elementwiseOp.getType())); rewriter.replaceOpWithNewOp( elementwiseOp, coopType, ValueRange{matrix, scalar}); return success(); diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp --- a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp @@ -61,7 +61,7 @@ /// Checks whether the given type is supported by Vulkan runtime. bool isSupportedType(Type type) { - if (auto memRefType = type.dyn_cast_or_null()) { + if (auto memRefType = dyn_cast_or_null(type)) { auto elementType = memRefType.getElementType(); return memRefType.hasRank() && (memRefType.getRank() >= 1 && memRefType.getRank() <= 3) && @@ -197,7 +197,7 @@ // The below cast always succeeds as it has already been verified in // 'declareVulkanLaunchFunc' that these are MemRefs with compatible element // types. - elementTypes.push_back(type.cast().getElementType()); + elementTypes.push_back(cast(type).getElementType()); } vulkanLaunchCallOp->setAttr(kSPIRVElementTypesAttrName, builder.getTypeArrayAttr(elementTypes)); diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp --- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp @@ -142,11 +142,11 @@ /// Returns a string representation from the given `type`. StringRef stringifyType(Type type) { - if (type.isa()) + if (isa(type)) return "Float"; - if (type.isa()) + if (isa(type)) return "Half"; - if (auto intType = type.dyn_cast()) { + if (auto intType = dyn_cast(type)) { if (intType.getWidth() == 32) return "Int32"; if (intType.getWidth() == 16) @@ -282,7 +282,7 @@ llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str(); // Special case for fp16 type. Since it is not a supported type in C we use // int16_t and bitcast the descriptor.
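For reference, a minimal sketch of the free-function cast style this diff adopts, in the shape of the `stringifyType` helper above. The angle-bracketed template arguments, which the rendering of this diff has dropped throughout, are restored here; `describeVulkanType` is a hypothetical stand-in, not the patch's code.

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LLVM.h" // re-exports llvm::isa/cast/dyn_cast into mlir

static llvm::StringRef describeVulkanType(mlir::Type type) {
  // Free-function form: isa<Float32Type>(type), not type.isa<Float32Type>().
  if (mlir::isa<mlir::Float32Type>(type))
    return "Float";
  if (mlir::isa<mlir::Float16Type>(type))
    return "Half";
  // dyn_cast<IntegerType>(type) likewise replaces type.dyn_cast<IntegerType>().
  if (auto intType = mlir::dyn_cast<mlir::IntegerType>(type)) {
    if (intType.getWidth() == 32)
      return "Int32";
    if (intType.getWidth() == 16)
      return "Int16";
  }
  return "Unsupported";
}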
- if (!useOpaquePointers && type.isa()) { + if (!useOpaquePointers && isa(type)) { auto memRefTy = getMemRefType(rank, IntegerType::get(&getContext(), 16)); ptrToMemRefDescriptor = builder.create( loc, LLVM::LLVMPointerType::get(memRefTy), ptrToMemRefDescriptor); @@ -328,9 +328,8 @@ rank = 0; return success(); } - rank = llvmDescriptorTy.getBody()[3] - .cast() - .getNumElements(); + rank = + cast(llvmDescriptorTy.getBody()[3]).getNumElements(); return success(); } @@ -375,7 +374,7 @@ for (auto type : types) { std::string fnName = "bindMemRef" + std::to_string(i) + "D" + std::string(stringifyType(type)); - if (type.isa()) + if (isa(type)) type = IntegerType::get(&getContext(), 16); if (!module.lookupSymbol(fnName)) { auto fnType = LLVM::LLVMFunctionType::get( diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp --- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp +++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp @@ -24,8 +24,7 @@ MemRefDescriptor::MemRefDescriptor(Value descriptor) : StructBuilder(descriptor) { assert(value != nullptr && "value cannot be null"); - indexType = value.getType() - .cast() + indexType = cast(value.getType()) .getBody()[kOffsetPosInMemRefDescriptor]; } @@ -193,10 +192,10 @@ } LLVM::LLVMPointerType MemRefDescriptor::getElementPtrType() { - return value.getType() - .cast() - .getBody()[kAlignedPtrPosInMemRefDescriptor] - .cast(); + return cast( + value.getType() + .cast() + .getBody()[kAlignedPtrPosInMemRefDescriptor]); } Value MemRefDescriptor::bufferPtr(OpBuilder &builder, Location loc, diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -235,7 +235,7 @@ SmallVector unrankedMemrefs; SmallVector unrankedAddressSpaces; for (unsigned i = 0, e = operands.size(); i < e; ++i) { - if (auto memRefType = origTypes[i].dyn_cast()) { + if (auto memRefType = dyn_cast(origTypes[i])) { unrankedMemrefs.emplace_back(operands[i]); FailureOr addressSpace = getTypeConverter()->getMemRefAddressSpace(memRefType); @@ -276,7 +276,7 @@ unsigned unrankedMemrefPos = 0; for (unsigned i = 0, e = operands.size(); i < e; ++i) { Type type = origTypes[i]; - if (!type.isa()) + if (!isa(type)) continue; Value allocationSize = sizes[unrankedMemrefPos++]; UnrankedMemRefDescriptor desc(operands[i]); diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -260,7 +260,7 @@ if (!resultType) return {}; - auto structType = resultType.dyn_cast(); + auto structType = dyn_cast(resultType); if (structType) { // Struct types cannot be safely returned via C interface. Make this a // pointer argument, instead. @@ -272,7 +272,7 @@ auto converted = convertType(t); if (!converted || !LLVM::isCompatibleType(converted)) return {}; - if (t.isa()) + if (isa(t)) converted = getPointerType(converted); inputs.push_back(converted); } @@ -412,13 +412,13 @@ // Check if a memref type can be converted to a bare pointer. bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type) { - if (type.isa()) + if (isa(type)) // Unranked memref is not supported in the bare pointer calling convention. return false; // Check that the memref has static shape, strides and offset. Otherwise, it // cannot be lowered to a bare pointer. 
- auto memrefTy = type.cast(); + auto memrefTy = cast(type); if (!memrefTy.hasStaticShape()) return false; @@ -476,7 +476,7 @@ Type LLVMTypeConverter::convertCallingConventionType(Type type, bool useBarePtrCallConv) { if (useBarePtrCallConv) - if (auto memrefTy = type.dyn_cast()) + if (auto memrefTy = dyn_cast(type)) return convertMemRefToBarePtr(memrefTy); return convertType(type); @@ -491,7 +491,7 @@ assert(stdTypes.size() == values.size() && "The number of types and values doesn't match"); for (unsigned i = 0, end = values.size(); i < end; ++i) - if (auto memrefTy = stdTypes[i].dyn_cast()) + if (auto memrefTy = dyn_cast(stdTypes[i])) values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this, memrefTy, values[i]); } @@ -569,19 +569,19 @@ if (useBarePtrCallConv) { // For the bare-ptr calling convention, we only have to extract the // aligned pointer of a memref. - if (auto memrefType = operand.getType().dyn_cast()) { + if (auto memrefType = dyn_cast(operand.getType())) { MemRefDescriptor desc(llvmOperand); llvmOperand = desc.alignedPtr(builder, loc); - } else if (operand.getType().isa()) { + } else if (isa(operand.getType())) { llvm_unreachable("Unranked memrefs are not supported"); } } else { - if (operand.getType().isa()) { + if (isa(operand.getType())) { UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand, promotedOperands); continue; } - if (auto memrefType = operand.getType().dyn_cast()) { + if (auto memrefType = dyn_cast(operand.getType())) { MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType, promotedOperands); continue; @@ -600,7 +600,7 @@ LogicalResult mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter, Type type, SmallVectorImpl &result) { - if (auto memref = type.dyn_cast()) { + if (auto memref = dyn_cast(type)) { // In signatures, Memref descriptors are expanded into lists of // non-aggregate values. 
auto converted = @@ -610,7 +610,7 @@ result.append(converted.begin(), converted.end()); return success(); } - if (type.isa()) { + if (isa(type)) { auto converted = converter.getUnrankedMemRefDescriptorFields(); if (converted.empty()) return failure(); diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp --- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp @@ -27,10 +27,10 @@ } info.arraySizes.reserve(vectorType.getRank() - 1); auto llvmTy = info.llvmNDVectorTy; - while (llvmTy.isa()) { + while (isa(llvmTy)) { info.arraySizes.push_back( - llvmTy.cast().getNumElements()); - llvmTy = llvmTy.cast().getElementType(); + cast(llvmTy).getNumElements()); + llvmTy = cast(llvmTy).getElementType(); } if (!LLVM::isCompatibleVectorType(llvmTy)) return info; @@ -81,7 +81,7 @@ Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter, std::function createOperand, ConversionPatternRewriter &rewriter) { - auto resultNDVectorType = op->getResult(0).getType().cast(); + auto resultNDVectorType = cast(op->getResult(0).getType()); auto resultTypeInfo = extractNDVectorTypeInfo(resultNDVectorType, typeConverter); auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy; @@ -114,7 +114,7 @@ return failure(); auto llvmNDVectorTy = operands[0].getType(); - if (!llvmNDVectorTy.isa()) + if (!isa(llvmNDVectorTy)) return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter, rewriter); diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp --- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp +++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp @@ -42,7 +42,7 @@ // The underlying descriptor type (e.g. LLVM) does not have layout // information. Canonicalizing the type at the level of std when going into // a library call avoids needing to introduce DialectCastOp. 
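A sketch of the array-peeling loop from `extractNDVectorTypeInfo` above, with template arguments restored and the isa/cast pair folded into a single `dyn_cast` (the condition and the cast then come from one lookup); assumes the LLVM dialect types, illustrative only.

#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallVector.h"

static mlir::Type peelToInnermost1DVector(mlir::Type llvmTy,
                                          llvm::SmallVectorImpl<int64_t> &sizes) {
  // Each LLVMArrayType layer contributes one outer dimension.
  while (auto arrayTy = mlir::dyn_cast<mlir::LLVM::LLVMArrayType>(llvmTy)) {
    sizes.push_back(arrayTy.getNumElements());
    llvmTy = arrayTy.getElementType();
  }
  return llvmTy; // the innermost, presumably LLVM-compatible, 1-D vector type
}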
- if (auto memrefType = type.dyn_cast()) + if (auto memrefType = dyn_cast(type)) result.push_back(makeStridedLayoutDynamic(memrefType)); else result.push_back(type); @@ -96,7 +96,7 @@ SmallVector res; res.reserve(operands.size()); for (auto op : operands) { - auto memrefType = op.getType().dyn_cast(); + auto memrefType = dyn_cast(op.getType()); if (!memrefType) { res.push_back(op); continue; diff --git a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp --- a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp +++ b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp @@ -106,7 +106,7 @@ VecOpToScalarOp::matchAndRewrite(Op op, PatternRewriter &rewriter) const { Type opType = op.getType(); Location loc = op.getLoc(); - auto vecType = opType.template dyn_cast(); + auto vecType = dyn_cast(opType); if (!vecType) return rewriter.notifyMatchFailure(op, "not a vector operation"); @@ -117,7 +117,7 @@ Type resultElementType = vecType.getElementType(); Attribute initValueAttr; - if (resultElementType.isa()) + if (isa(resultElementType)) initValueAttr = FloatAttr::get(resultElementType, 0.0); else initValueAttr = IntegerAttr::get(resultElementType, 0); @@ -183,7 +183,7 @@ /// } /// } static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) { - assert(elementType.isa() && + assert(isa(elementType) && "non-integer element type for IPowIOp"); ImplicitLocOpBuilder builder = @@ -361,7 +361,7 @@ LogicalResult IPowIOpLowering::matchAndRewrite(math::IPowIOp op, PatternRewriter &rewriter) const { - auto baseType = op.getOperands()[0].getType().dyn_cast(); + auto baseType = dyn_cast(op.getOperands()[0].getType()); if (!baseType) return rewriter.notifyMatchFailure(op, "non-integer base operand"); @@ -411,8 +411,8 @@ /// } static func::FuncOp createElementFPowIFunc(ModuleOp *module, FunctionType funcType) { - auto baseType = funcType.getInput(0).cast(); - auto powType = funcType.getInput(1).cast(); + auto baseType = cast(funcType.getInput(0)); + auto powType = cast(funcType.getInput(1)); ImplicitLocOpBuilder builder = ImplicitLocOpBuilder::atBlockEnd(module->getLoc(), module->getBody()); @@ -586,7 +586,7 @@ LogicalResult FPowIOpLowering::matchAndRewrite(math::FPowIOp op, PatternRewriter &rewriter) const { - if (op.getType().template dyn_cast()) + if (dyn_cast(op.getType())) return rewriter.notifyMatchFailure(op, "non-scalar operation"); FunctionType funcType = getElementalFuncTypeForOp(op); @@ -649,7 +649,7 @@ /// return %out: i32 /// } static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) { - if (!elementType.isa()) { + if (!isa(elementType)) { LLVM_DEBUG({ DBGS() << "non-integer element type for CtlzFunc; type was: "; elementType.print(llvm::dbgs()); @@ -751,7 +751,7 @@ /// operation. 
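As an aside on the VecOpToScalarOp pattern above, the zero-initializer it builds depends on whether the element type is floating point; a hedged sketch with template arguments restored, assuming builtin attribute types:

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"

static mlir::Attribute getZeroInitAttr(mlir::Type elemTy) {
  // isa<FloatType>(elemTy) replaces elemTy.isa<FloatType>().
  if (mlir::isa<mlir::FloatType>(elemTy))
    return mlir::FloatAttr::get(elemTy, 0.0);
  return mlir::IntegerAttr::get(elemTy, 0);
}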
LogicalResult CtlzOpLowering::matchAndRewrite(math::CountLeadingZerosOp op, PatternRewriter &rewriter) const { - if (op.getType().template dyn_cast()) + if (dyn_cast(op.getType())) return rewriter.notifyMatchFailure(op, "non-scalar operation"); Type type = getElementTypeOrSelf(op.getResult().getType()); @@ -794,7 +794,7 @@ bool ConvertMathToFuncsPass::isFPowIConvertible(math::FPowIOp op) { auto expTy = - getElementTypeOrSelf(op.getRhs().getType()).dyn_cast(); + dyn_cast(getElementTypeOrSelf(op.getRhs().getType())); return (expTy && expTy.getWidth() >= minWidthOfFPowIExponent); } diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp --- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp +++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp @@ -79,14 +79,14 @@ auto resultType = op.getResult().getType(); auto boolZero = rewriter.getBoolAttr(false); - if (!operandType.template isa()) { + if (!isa(operandType)) { LLVM::ConstantOp zero = rewriter.create(loc, boolZero); rewriter.replaceOpWithNewOp(op, resultType, adaptor.getOperand(), zero); return success(); } - auto vectorType = resultType.template dyn_cast(); + auto vectorType = dyn_cast(resultType); if (!vectorType) return failure(); @@ -122,17 +122,17 @@ auto loc = op.getLoc(); auto resultType = op.getResult().getType(); - auto floatType = getElementTypeOrSelf(resultType).cast(); + auto floatType = cast(getElementTypeOrSelf(resultType)); auto floatOne = rewriter.getFloatAttr(floatType, 1.0); ConvertFastMath expAttrs(op); ConvertFastMath subAttrs(op); - if (!operandType.isa()) { + if (!isa(operandType)) { LLVM::ConstantOp one; if (LLVM::isCompatibleVectorType(operandType)) { one = rewriter.create( loc, operandType, - SplatElementsAttr::get(resultType.cast(), floatOne)); + SplatElementsAttr::get(cast(resultType), floatOne)); } else { one = rewriter.create(loc, operandType, floatOne); } @@ -143,7 +143,7 @@ return success(); } - auto vectorType = resultType.dyn_cast(); + auto vectorType = dyn_cast(resultType); if (!vectorType) return rewriter.notifyMatchFailure(op, "expected vector result type"); @@ -180,17 +180,17 @@ auto loc = op.getLoc(); auto resultType = op.getResult().getType(); - auto floatType = getElementTypeOrSelf(resultType).cast(); + auto floatType = cast(getElementTypeOrSelf(resultType)); auto floatOne = rewriter.getFloatAttr(floatType, 1.0); ConvertFastMath addAttrs(op); ConvertFastMath logAttrs(op); - if (!operandType.isa()) { + if (!isa(operandType)) { LLVM::ConstantOp one = LLVM::isCompatibleVectorType(operandType) ? 
rewriter.create( loc, operandType, - SplatElementsAttr::get(resultType.cast(), + SplatElementsAttr::get(cast(resultType), floatOne)) : rewriter.create(loc, operandType, floatOne); @@ -202,7 +202,7 @@ return success(); } - auto vectorType = resultType.dyn_cast(); + auto vectorType = dyn_cast(resultType); if (!vectorType) return rewriter.notifyMatchFailure(op, "expected vector result type"); @@ -240,17 +240,17 @@ auto loc = op.getLoc(); auto resultType = op.getResult().getType(); - auto floatType = getElementTypeOrSelf(resultType).cast(); + auto floatType = cast(getElementTypeOrSelf(resultType)); auto floatOne = rewriter.getFloatAttr(floatType, 1.0); ConvertFastMath sqrtAttrs(op); ConvertFastMath divAttrs(op); - if (!operandType.isa()) { + if (!isa(operandType)) { LLVM::ConstantOp one; if (LLVM::isCompatibleVectorType(operandType)) { one = rewriter.create( loc, operandType, - SplatElementsAttr::get(resultType.cast(), floatOne)); + SplatElementsAttr::get(cast(resultType), floatOne)); } else { one = rewriter.create(loc, operandType, floatOne); } @@ -261,7 +261,7 @@ return success(); } - auto vectorType = resultType.dyn_cast(); + auto vectorType = dyn_cast(resultType); if (!vectorType) return failure(); diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp --- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp +++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp @@ -75,7 +75,7 @@ VecOpToScalarOp::matchAndRewrite(Op op, PatternRewriter &rewriter) const { auto opType = op.getType(); auto loc = op.getLoc(); - auto vecType = opType.template dyn_cast(); + auto vecType = dyn_cast(opType); if (!vecType) return failure(); @@ -107,7 +107,7 @@ LogicalResult PromoteOpToF32::matchAndRewrite(Op op, PatternRewriter &rewriter) const { auto opType = op.getType(); - if (!opType.template isa()) + if (!isa(opType)) return failure(); auto loc = op.getLoc(); @@ -127,7 +127,7 @@ PatternRewriter &rewriter) const { auto module = SymbolTable::getNearestSymbolTable(op); auto type = op.getType(); - if (!type.template isa()) + if (!isa(type)) return failure(); auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc; diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -34,7 +34,7 @@ /// given type is not a 32-bit scalar/vector type. 
static Value getScalarOrVectorI32Constant(Type type, int value, OpBuilder &builder, Location loc) { - if (auto vectorType = type.dyn_cast()) { + if (auto vectorType = dyn_cast(type)) { if (!vectorType.getElementType().isInteger(32)) return nullptr; SmallVector values(vectorType.getNumElements(), value); @@ -55,7 +55,7 @@ if (originalType.isIntOrIndexOrFloat()) return true; - if (auto vecTy = originalType.dyn_cast()) { + if (auto vecTy = dyn_cast(originalType)) { if (!vecTy.getElementType().isIntOrIndexOrFloat()) return false; if (vecTy.isScalable()) @@ -133,10 +133,10 @@ return failure(); FloatType floatType; - if (auto scalarType = copySignOp.getType().dyn_cast()) { + if (auto scalarType = dyn_cast(copySignOp.getType())) { floatType = scalarType; - } else if (auto vectorType = copySignOp.getType().dyn_cast()) { - floatType = vectorType.getElementType().cast(); + } else if (auto vectorType = dyn_cast(copySignOp.getType())) { + floatType = cast(vectorType.getElementType()); } else { return failure(); } @@ -151,7 +151,7 @@ Value valueMask = rewriter.create( loc, intType, rewriter.getIntegerAttr(intType, intValue - 1u)); - if (auto vectorType = type.dyn_cast()) { + if (auto vectorType = dyn_cast(type)) { assert(vectorType.getRank() == 1); int count = vectorType.getNumElements(); intType = VectorType::get(count, intType); @@ -203,9 +203,9 @@ // We can only support 32-bit integer types for now. unsigned bitwidth = 0; - if (type.isa()) + if (isa(type)) bitwidth = type.getIntOrFloatBitWidth(); - if (auto vectorType = type.dyn_cast()) + if (auto vectorType = dyn_cast(type)) bitwidth = vectorType.getElementTypeBitWidth(); if (bitwidth != 32) return failure(); @@ -338,7 +338,7 @@ auto zero = spirv::ConstantOp::getZero(ty, loc, rewriter); auto one = spirv::ConstantOp::getOne(ty, loc, rewriter); Value half; - if (VectorType vty = ty.dyn_cast()) { + if (VectorType vty = dyn_cast(ty)) { half = rewriter.create( loc, vty, DenseElementsAttr::get(vty, diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp @@ -58,7 +58,7 @@ Location loc, Value allocatedPtr, MemRefType memRefType, Type elementPtrType, LLVMTypeConverter &typeConverter) { - auto allocatedPtrTy = allocatedPtr.getType().cast(); + auto allocatedPtrTy = cast(allocatedPtr.getType()); unsigned memrefAddrSpace = *typeConverter.getMemRefAddressSpace(memRefType); if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace) allocatedPtr = rewriter.create( @@ -114,10 +114,10 @@ layout = &analysis->getAbove(op); } Type elementType = memRefType.getElementType(); - if (auto memRefElementType = elementType.dyn_cast()) + if (auto memRefElementType = dyn_cast(elementType)) return getTypeConverter()->getMemRefDescriptorSize(memRefElementType, *layout); - if (auto memRefElementType = elementType.dyn_cast()) + if (auto memRefElementType = dyn_cast(elementType)) return getTypeConverter()->getUnrankedMemRefDescriptorSize( memRefElementType, *layout); return layout->getTypeSize(elementType); diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -184,10 +184,10 @@ rewriter.setInsertionPointToEnd(currentBlock); Value src = op.getSource(); - auto srcType = src.getType().dyn_cast(); + auto srcType = 
dyn_cast(src.getType()); Value srcNumElements = computeNumElements( srcType, [&]() -> Value { return desc.size(rewriter, loc, 0); }); - auto dstType = op.getType().cast(); + auto dstType = cast(op.getType()); Value dstNumElements = computeNumElements( dstType, [&]() -> Value { return op.getDynamicResultSize(); }); Value cond = rewriter.create( @@ -342,7 +342,7 @@ unsigned alignment = op.getAlignment(); auto loc = op.getLoc(); - auto srcMemRefType = op.getMemref().getType().cast(); + auto srcMemRefType = cast(op.getMemref().getType()); Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{}, rewriter); @@ -417,7 +417,7 @@ matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type operandType = dimOp.getSource().getType(); - if (operandType.isa()) { + if (isa(operandType)) { FailureOr extractedSize = extractSizeOfUnrankedMemRef( operandType, dimOp, adaptor.getOperands(), rewriter); if (failed(extractedSize)) @@ -425,7 +425,7 @@ rewriter.replaceOp(dimOp, {*extractedSize}); return success(); } - if (operandType.isa()) { + if (isa(operandType)) { rewriter.replaceOp( dimOp, {extractSizeOfRankedMemRef(operandType, dimOp, adaptor.getOperands(), rewriter)}); @@ -441,7 +441,7 @@ ConversionPatternRewriter &rewriter) const { Location loc = dimOp.getLoc(); - auto unrankedMemRefType = operandType.cast(); + auto unrankedMemRefType = cast(operandType); auto scalarMemRefType = MemRefType::get({}, unrankedMemRefType.getElementType()); FailureOr maybeAddressSpace = @@ -492,10 +492,7 @@ return idx; if (auto constantOp = dimOp.getIndex().getDefiningOp()) - return constantOp.getValue() - .cast() - .getValue() - .getSExtValue(); + return cast(constantOp.getValue()).getValue().getSExtValue(); return std::nullopt; } @@ -506,7 +503,7 @@ Location loc = dimOp.getLoc(); // Take advantage if index is constant. - MemRefType memRefType = operandType.cast(); + MemRefType memRefType = cast(operandType); if (std::optional index = getConstantDimIndex(dimOp)) { int64_t i = *index; if (i >= 0 && i < memRefType.getRank()) { @@ -589,7 +586,7 @@ // Compute the loaded value and branch to the loop block. rewriter.setInsertionPointToEnd(initBlock); - auto memRefType = atomicOp.getMemref().getType().cast(); + auto memRefType = cast(atomicOp.getMemref().getType()); auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(), adaptor.getIndices(), rewriter); Value init = rewriter.create( @@ -712,7 +709,7 @@ Location loc, Value sizeBytes, Operation *op) const override { auto getGlobalOp = cast(op); - MemRefType type = getGlobalOp.getResult().getType().cast(); + MemRefType type = cast(getGlobalOp.getResult().getType()); // This is called after a type conversion, which would have failed if this // call fails. @@ -823,12 +820,12 @@ ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Type operandType = op.getMemref().getType(); - if (auto unrankedMemRefType = operandType.dyn_cast()) { + if (auto unrankedMemRefType = dyn_cast(operandType)) { UnrankedMemRefDescriptor desc(adaptor.getMemref()); rewriter.replaceOp(op, {desc.rank(rewriter, loc)}); return success(); } - if (auto rankedMemRefType = operandType.dyn_cast()) { + if (auto rankedMemRefType = dyn_cast(operandType)) { rewriter.replaceOp( op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())}); return success(); @@ -849,17 +846,17 @@ // and require source and result type to have the same rank. 
Therefore, // perform a sanity check that the underlying structs are the same. Once op // semantics are relaxed we can revisit. - if (srcType.isa() && dstType.isa()) + if (isa(srcType) && isa(dstType)) return success(typeConverter->convertType(srcType) == typeConverter->convertType(dstType)); // At least one of the operands is unranked type - assert(srcType.isa() || - dstType.isa()); + assert(isa(srcType) || + isa(dstType)); // Unranked to unranked cast is disallowed - return !(srcType.isa() && - dstType.isa()) + return !(isa(srcType) && + isa(dstType)) ? success() : failure(); } @@ -872,15 +869,15 @@ auto loc = memRefCastOp.getLoc(); // For ranked/ranked case, just keep the original descriptor. - if (srcType.isa() && dstType.isa()) + if (isa(srcType) && isa(dstType)) return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()}); - if (srcType.isa() && dstType.isa()) { + if (isa(srcType) && isa(dstType)) { // Casting ranked to unranked memref type // Set the rank in the destination from the memref type // Allocate space on the stack and copy the src memref descriptor // Set the ptr in the destination to the stack space - auto srcMemRefType = srcType.cast(); + auto srcMemRefType = cast(srcType); int64_t rank = srcMemRefType.getRank(); // ptr = AllocaOp sizeof(MemRefDescriptor) auto ptr = getTypeConverter()->promoteOneMemRefDescriptor( @@ -905,7 +902,7 @@ memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr); rewriter.replaceOp(memRefCastOp, (Value)memRefDesc); - } else if (srcType.isa() && dstType.isa()) { + } else if (isa(srcType) && isa(dstType)) { // Casting from unranked type to ranked. // The operation is assumed to be doing a correct cast. If the destination // type mismatches the unranked the type, it is undefined behavior. @@ -942,7 +939,7 @@ lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto loc = op.getLoc(); - auto srcType = op.getSource().getType().dyn_cast(); + auto srcType = dyn_cast(op.getSource().getType()); MemRefDescriptor srcDesc(adaptor.getSource()); @@ -984,8 +981,8 @@ lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto loc = op.getLoc(); - auto srcType = op.getSource().getType().cast(); - auto targetType = op.getTarget().getType().cast(); + auto srcType = cast(op.getSource().getType()); + auto targetType = cast(op.getTarget().getType()); // First make sure we have an unranked memref descriptor representation. auto makeUnranked = [&, this](Value ranked, MemRefType type) { @@ -1012,11 +1009,11 @@ auto stackSaveOp = rewriter.create(loc, getVoidPtrType()); - auto srcMemRefType = srcType.dyn_cast(); + auto srcMemRefType = dyn_cast(srcType); Value unrankedSource = srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType) : adaptor.getSource(); - auto targetMemRefType = targetType.dyn_cast(); + auto targetMemRefType = dyn_cast(targetType); Value unrankedTarget = targetMemRefType ? 
makeUnranked(adaptor.getTarget(), targetMemRefType) : adaptor.getTarget(); @@ -1055,8 +1052,8 @@ LogicalResult matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto srcType = op.getSource().getType().cast(); - auto targetType = op.getTarget().getType().cast(); + auto srcType = cast(op.getSource().getType()); + auto targetType = cast(op.getTarget().getType()); auto isStaticShapeAndContiguousRowMajor = [](MemRefType type) { if (!type.hasStaticShape()) @@ -1077,7 +1074,7 @@ }; auto isContiguousMemrefType = [&](BaseMemRefType type) { - auto memrefType = type.dyn_cast(); + auto memrefType = dyn_cast(type); // We can use memcpy for memrefs if they have an identity layout or are // contiguous with an arbitrary offset. Ignore empty memrefs, which is a // special case handled by memrefCopy. @@ -1105,9 +1102,9 @@ Location loc = op.getLoc(); Type resultType = op.getDest().getType(); - if (auto resultTypeR = resultType.dyn_cast()) { + if (auto resultTypeR = dyn_cast(resultType)) { auto resultDescType = - typeConverter->convertType(resultTypeR).cast(); + cast(typeConverter->convertType(resultTypeR)); Type newPtrType = resultDescType.getBody()[0]; SmallVector descVals; @@ -1122,10 +1119,10 @@ rewriter.replaceOp(op, result); return success(); } - if (auto resultTypeU = resultType.dyn_cast()) { + if (auto resultTypeU = dyn_cast(resultType)) { // Since the type converter won't be doing this for us, get the address // space. - auto sourceType = op.getSource().getType().cast(); + auto sourceType = cast(op.getSource().getType()); FailureOr maybeSourceAddrSpace = getTypeConverter()->getMemRefAddressSpace(sourceType); if (failed(maybeSourceAddrSpace)) @@ -1217,7 +1214,7 @@ Value *allocatedPtr, Value *alignedPtr, Value *offset = nullptr) { Type operandType = originalOperand.getType(); - if (operandType.isa()) { + if (isa(operandType)) { MemRefDescriptor desc(convertedOperand); *allocatedPtr = desc.allocatedPtr(rewriter, loc); *alignedPtr = desc.alignedPtr(rewriter, loc); @@ -1228,8 +1225,8 @@ // These will all cause assert()s on unconvertible types. 
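The memcpy-based memref.copy path above hinges on contiguity. A minimal sketch of the simplest qualifying case, static shape with an identity (row-major, offset-zero) layout; the real check above also accepts contiguous strided layouts with an arbitrary offset:

#include "mlir/IR/BuiltinTypes.h"

static bool isStaticIdentityMemRef(mlir::MemRefType type) {
  // Static shape + identity layout means the buffer is one dense block,
  // so a single memcpy of size(elements) * sizeof(element) is valid.
  return type.hasStaticShape() && type.getLayout().isIdentity();
}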
unsigned memorySpace = *typeConverter.getMemRefAddressSpace( - operandType.cast()); - Type elementType = operandType.cast().getElementType(); + cast(operandType)); + Type elementType = cast(operandType).getElementType(); Type llvmElementType = typeConverter.convertType(elementType); LLVM::LLVMPointerType elementPtrType = typeConverter.getPointerType(llvmElementType, memorySpace); @@ -1273,9 +1270,9 @@ memref::ReinterpretCastOp castOp, memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const { MemRefType targetMemRefType = - castOp.getResult().getType().cast(); - auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType) - .dyn_cast_or_null(); + cast(castOp.getResult().getType()); + auto llvmTargetDescriptorTy = dyn_cast_or_null( + typeConverter->convertType(targetMemRefType)); if (!llvmTargetDescriptorTy) return failure(); @@ -1339,13 +1336,12 @@ Type srcType, memref::ReshapeOp reshapeOp, memref::ReshapeOp::Adaptor adaptor, Value *descriptor) const { - auto shapeMemRefType = reshapeOp.getShape().getType().cast(); + auto shapeMemRefType = cast(reshapeOp.getShape().getType()); if (shapeMemRefType.hasStaticShape()) { MemRefType targetMemRefType = - reshapeOp.getResult().getType().cast(); - auto llvmTargetDescriptorTy = - typeConverter->convertType(targetMemRefType) - .dyn_cast_or_null(); + cast(reshapeOp.getResult().getType()); + auto llvmTargetDescriptorTy = dyn_cast_or_null( + typeConverter->convertType(targetMemRefType)); if (!llvmTargetDescriptorTy) return failure(); @@ -1426,8 +1422,7 @@ Value resultRank = shapeDesc.size(rewriter, loc, 0); // Extract address space and element type. - auto targetType = - reshapeOp.getResult().getType().cast(); + auto targetType = cast(reshapeOp.getResult().getType()); unsigned addressSpace = *getTypeConverter()->getMemRefAddressSpace(targetType); Type elementType = targetType.getElementType(); @@ -1695,7 +1690,7 @@ // Field 1: Copy the allocated pointer, used for malloc/free. Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc); - auto srcMemRefType = viewOp.getSource().getType().cast(); + auto srcMemRefType = cast(viewOp.getSource().getType()); unsigned sourceMemorySpace = *getTypeConverter()->getMemRefAddressSpace(srcMemRefType); Value bitcastPtr; @@ -1848,7 +1843,7 @@ Location loc = extractStridedMetadataOp.getLoc(); Value source = extractStridedMetadataOp.getSource(); - auto sourceMemRefType = source.getType().cast(); + auto sourceMemRefType = cast(source.getType()); int64_t rank = sourceMemRefType.getRank(); SmallVector results; results.reserve(2 + rank * 2); @@ -1858,7 +1853,7 @@ Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc); MemRefDescriptor dstMemRef = MemRefDescriptor::fromStaticShape( rewriter, loc, *getTypeConverter(), - extractStridedMetadataOp.getBaseBuffer().getType().cast(), + cast(extractStridedMetadataOp.getBaseBuffer().getType()), baseBuffer, alignedBuffer); results.push_back((Value)dstMemRef); diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp --- a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp @@ -64,7 +64,7 @@ // Unknown dialect custom attributes are not supported by default. // Downstream callers should plug in more specialized ones. 
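A sketch of the default memory-space hook described above: only builtin integer attributes are recognized, and everything else yields std::nullopt for downstream callers to handle. `dyn_cast_or_null` also covers the absent-attribute case.

#include "mlir/IR/BuiltinAttributes.h"
#include <optional>

static std::optional<unsigned> getIntMemorySpace(mlir::Attribute memorySpaceAttr) {
  // dyn_cast_or_null<IntegerAttr>(attr) replaces attr.dyn_cast_or_null<IntegerAttr>().
  auto intAttr = mlir::dyn_cast_or_null<mlir::IntegerAttr>(memorySpaceAttr);
  if (!intAttr)
    return std::nullopt;
  return static_cast<unsigned>(intAttr.getInt());
}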
- auto intAttr = memorySpaceAttr.dyn_cast(); + auto intAttr = dyn_cast(memorySpaceAttr); if (!intAttr) return std::nullopt; unsigned memorySpace = intAttr.getInt(); @@ -118,7 +118,7 @@ // Unknown dialect custom attributes are not supported by default. // Downstream callers should plug in more specialized ones. - auto intAttr = memorySpaceAttr.dyn_cast(); + auto intAttr = dyn_cast(memorySpaceAttr); if (!intAttr) return std::nullopt; unsigned memorySpace = intAttr.getInt(); @@ -177,7 +177,7 @@ auto storageAttr = spirv::StorageClassAttr::get(memRefType.getContext(), *storage); - if (auto rankedType = memRefType.dyn_cast()) { + if (auto rankedType = dyn_cast(memRefType)) { return MemRefType::get(memRefType.getShape(), memRefType.getElementType(), rankedType.getLayout(), storageAttr); } @@ -203,9 +203,9 @@ /// Returns true if the given `type` is considered as legal for SPIR-V /// conversion. static bool isLegalType(Type type) { - if (auto memRefType = type.dyn_cast()) { + if (auto memRefType = dyn_cast(type)) { Attribute spaceAttr = memRefType.getMemorySpace(); - return spaceAttr && spaceAttr.isa(); + return spaceAttr && isa(spaceAttr); } return true; } @@ -213,7 +213,7 @@ /// Returns true if the given `attr` is considered as legal for SPIR-V /// conversion. static bool isLegalAttr(Attribute attr) { - if (auto typeAttr = attr.dyn_cast()) + if (auto typeAttr = dyn_cast(attr)) return isLegalType(typeAttr.getValue()); return true; } @@ -266,7 +266,7 @@ llvm::SmallVector newAttrs; newAttrs.reserve(op->getAttrs().size()); for (auto attr : op->getAttrs()) { - if (auto typeAttr = attr.getValue().dyn_cast()) { + if (auto typeAttr = dyn_cast(attr.getValue())) { auto newAttr = getTypeConverter()->convertType(typeAttr.getValue()); newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr)); } else { diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -93,11 +93,11 @@ /// can be lowered to SPIR-V. static bool isAllocationSupported(Operation *allocOp, MemRefType type) { if (isa(allocOp)) { - auto sc = type.getMemorySpace().dyn_cast_or_null(); + auto sc = dyn_cast_or_null(type.getMemorySpace()); if (!sc || sc.getValue() != spirv::StorageClass::Workgroup) return false; } else if (isa(allocOp)) { - auto sc = type.getMemorySpace().dyn_cast_or_null(); + auto sc = dyn_cast_or_null(type.getMemorySpace()); if (!sc || sc.getValue() != spirv::StorageClass::Function) return false; } else { @@ -110,7 +110,7 @@ return false; Type elementType = type.getElementType(); - if (auto vecType = elementType.dyn_cast()) + if (auto vecType = dyn_cast(elementType)) elementType = vecType.getElementType(); return elementType.isIntOrFloat(); } @@ -119,7 +119,7 @@ /// operations of unsupported integer bitwidths, based on the memref /// type. Returns std::nullopt on failure. 
static std::optional getAtomicOpScope(MemRefType type) { - auto sc = type.getMemorySpace().dyn_cast_or_null(); + auto sc = dyn_cast_or_null(type.getMemorySpace()); switch (sc.getValue()) { case spirv::StorageClass::StorageBuffer: return spirv::Scope::Device; @@ -324,11 +324,11 @@ AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - if (atomicOp.getType().isa()) + if (isa(atomicOp.getType())) return rewriter.notifyMatchFailure(atomicOp, "unimplemented floating-point case"); - auto memrefType = atomicOp.getMemref().getType().cast(); + auto memrefType = cast(atomicOp.getMemref().getType()); std::optional scope = getAtomicOpScope(memrefType); if (!scope) return rewriter.notifyMatchFailure(atomicOp, @@ -380,7 +380,7 @@ DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - MemRefType deallocType = operation.getMemref().getType().cast(); + MemRefType deallocType = cast(operation.getMemref().getType()); if (!isAllocationSupported(operation, deallocType)) return rewriter.notifyMatchFailure(operation, "unhandled allocation type"); rewriter.eraseOp(operation); @@ -395,7 +395,7 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto loc = loadOp.getLoc(); - auto memrefType = loadOp.getMemref().getType().cast(); + auto memrefType = cast(loadOp.getMemref().getType()); if (!memrefType.getElementType().isSignlessInteger()) return failure(); @@ -419,18 +419,18 @@ Type pointeeType = pointerType.getPointeeType(); Type dstType; if (typeConverter.allows(spirv::Capability::Kernel)) { - if (auto arrayType = pointeeType.dyn_cast()) + if (auto arrayType = dyn_cast(pointeeType)) dstType = arrayType.getElementType(); else dstType = pointeeType; } else { // For Vulkan we need to extract element from wrapping struct and array. Type structElemType = - pointeeType.cast().getElementType(0); - if (auto arrayType = structElemType.dyn_cast()) + cast(pointeeType).getElementType(0); + if (auto arrayType = dyn_cast(structElemType)) dstType = arrayType.getElementType(); else - dstType = structElemType.cast().getElementType(); + dstType = cast(structElemType).getElementType(); } int dstBits = dstType.getIntOrFloatBitWidth(); assert(dstBits % srcBits == 0); @@ -509,7 +509,7 @@ LogicalResult LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - auto memrefType = loadOp.getMemref().getType().cast(); + auto memrefType = cast(loadOp.getMemref().getType()); if (memrefType.getElementType().isSignlessInteger()) return failure(); auto loadPtr = spirv::getElementPtr( @@ -526,7 +526,7 @@ LogicalResult IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - auto memrefType = storeOp.getMemref().getType().cast(); + auto memrefType = cast(storeOp.getMemref().getType()); if (!memrefType.getElementType().isSignlessInteger()) return failure(); @@ -553,18 +553,18 @@ Type pointeeType = pointerType.getPointeeType(); Type dstType; if (typeConverter.allows(spirv::Capability::Kernel)) { - if (auto arrayType = pointeeType.dyn_cast()) + if (auto arrayType = dyn_cast(pointeeType)) dstType = arrayType.getElementType(); else dstType = pointeeType; } else { // For Vulkan we need to extract element from wrapping struct and array. 
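A sketch of the Vulkan-side unwrapping the surrounding pattern performs: the storage-buffer pointee is a struct wrapping an array, so two levels are peeled to reach the element type. Simplified relative to the pattern (which also handles a vector element); assumes SPIR-V dialect types.

#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"

static mlir::Type unwrapVulkanPointee(mlir::Type pointeeType) {
  // Vulkan wraps the data in a struct whose first member is the array.
  auto structTy = mlir::cast<mlir::spirv::StructType>(pointeeType);
  mlir::Type elem = structTy.getElementType(0);
  if (auto arrayTy = mlir::dyn_cast<mlir::spirv::ArrayType>(elem))
    return arrayTy.getElementType();
  return elem;
}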
Type structElemType = - pointeeType.cast().getElementType(0); - if (auto arrayType = structElemType.dyn_cast()) + cast(pointeeType).getElementType(0); + if (auto arrayType = dyn_cast(structElemType)) dstType = arrayType.getElementType(); else - dstType = structElemType.cast().getElementType(); + dstType = cast(structElemType).getElementType(); } int dstBits = dstType.getIntOrFloatBitWidth(); @@ -651,21 +651,21 @@ return rewriter.notifyMatchFailure( loc, "address space casts require kernel capability"); - auto sourceType = addrCastOp.getSource().getType().dyn_cast(); + auto sourceType = dyn_cast(addrCastOp.getSource().getType()); if (!sourceType) return rewriter.notifyMatchFailure( loc, "SPIR-V lowering requires ranked memref types"); - auto resultType = addrCastOp.getResult().getType().cast(); + auto resultType = cast(addrCastOp.getResult().getType()); auto sourceStorageClassAttr = - sourceType.getMemorySpace().dyn_cast_or_null(); + dyn_cast_or_null(sourceType.getMemorySpace()); if (!sourceStorageClassAttr) return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &diag) { diag << "source address space " << sourceType.getMemorySpace() << " must be a SPIR-V storage class"; }); auto resultStorageClassAttr = - resultType.getMemorySpace().dyn_cast_or_null(); + dyn_cast_or_null(resultType.getMemorySpace()); if (!resultStorageClassAttr) return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &diag) { diag << "result address space " << resultType.getMemorySpace() @@ -709,7 +709,7 @@ LogicalResult StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - auto memrefType = storeOp.getMemref().getType().cast(); + auto memrefType = cast(storeOp.getMemref().getType()); if (memrefType.getElementType().isSignlessInteger()) return failure(); auto storePtr = spirv::getElementPtr( diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -28,7 +28,7 @@ /// `gpu.mma.sync` operation. static Type inferIntrinsicResultType(Type vectorResultType) { MLIRContext *ctx = vectorResultType.getContext(); - auto a = vectorResultType.cast(); + auto a = cast(vectorResultType); auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2); auto i32Ty = IntegerType::get(ctx, 32); auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2); @@ -69,8 +69,8 @@ Type resultType, Value intrinsicResult, RewriterBase &rewriter) { MLIRContext *ctx = rewriter.getContext(); - auto structType = intrinsicResultType.dyn_cast(); - auto arrayType = resultType.dyn_cast(); + auto structType = dyn_cast(intrinsicResultType); + auto arrayType = dyn_cast(resultType); Type i32Ty = rewriter.getI32Type(); Type f32Ty = rewriter.getF32Type(); Type f64Ty = rewriter.getF64Type(); @@ -153,7 +153,7 @@ Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4); Type i4x8Ty = LLVM::getFixedVectorType(i4Ty, 8); Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1); - auto arrayTy = operand.getType().cast(); + auto arrayTy = cast(operand.getType()); for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) { Value toUse = rewriter.create(loc, operand, i); @@ -172,7 +172,7 @@ // For some element types (i32, f32, f64), we need to unpack the inner // vector/array type as well because the intrinsic expects individual // scalars to be provided. 
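Sketch of the scalar unpacking the NVGPU comment below describes: when a struct element is itself a short vector of i32/f32/f64, each lane is extracted so the intrinsic receives plain scalars. The helper name and the `b`/`loc` parameters are assumptions of the sketch, not the patch's code.

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
#include "llvm/ADT/SmallVector.h"

static void unpackInnerVector(mlir::OpBuilder &b, mlir::Location loc,
                              mlir::Value innerVec,
                              llvm::SmallVectorImpl<mlir::Value> &out) {
  auto vecTy = mlir::cast<mlir::VectorType>(innerVec.getType());
  for (int64_t lane = 0, e = vecTy.getNumElements(); lane < e; ++lane) {
    mlir::Value idx = b.create<mlir::LLVM::ConstantOp>(
        loc, b.getI64Type(), b.getI64IntegerAttr(lane));
    out.push_back(b.create<mlir::LLVM::ExtractElementOp>(loc, innerVec, idx));
  }
}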
- VectorType innerArrayTy = arrayTy.getElementType().dyn_cast(); + VectorType innerArrayTy = dyn_cast(arrayTy.getElementType()); if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty || innerArrayTy.getElementType() == f64Ty || innerArrayTy.getElementType() == f32Ty)) { @@ -207,7 +207,7 @@ // of shape (NumRegisters, VectorRegister) where VectorRegister is the // vector type of the result and always 32 bits long. We bitcast the result // of the NVVM::LdMatrix to this vector type. - auto vectorResultType = op->getResultTypes()[0].dyn_cast(); + auto vectorResultType = dyn_cast(op->getResultTypes()[0]); if (!vectorResultType) { return failure(); } @@ -224,7 +224,7 @@ ldMatrixResultType = rewriter.getI32Type(); } - auto srcMemrefType = op.getSrcMemref().getType().cast(); + auto srcMemrefType = cast(op.getSrcMemref().getType()); Value srcPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrcMemref(), adaptor.getIndices(), rewriter); @@ -307,7 +307,7 @@ // TODO: add an attribute to the op to customize this behavior. std::optional overflow(std::nullopt); - if (aType.getElementType().isa()) + if (isa(aType.getElementType())) overflow = NVVM::MMAIntOverflow::satfinite; SmallVector matA = @@ -388,7 +388,7 @@ // constant. auto dstByteConstOp = dyn_cast(dstBytes.getDefiningOp()); - auto dstByteAttr = dstByteConstOp.getValue().dyn_cast(); + auto dstByteAttr = dyn_cast(dstByteConstOp.getValue()); int64_t dstByteVal = dstByteAttr.getValue().getSExtValue(); assert((dstByteVal == 4 || dstByteVal == 8 || dstByteVal == 16) && @@ -537,7 +537,7 @@ // TODO: add an attribute to the op to customize this behavior. std::optional overflow(std::nullopt); - if (aType.getElementType().isa()) + if (isa(aType.getElementType())) overflow = NVVM::MMAIntOverflow::satfinite; SmallVector matA = @@ -585,7 +585,7 @@ matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); - auto dstMemrefType = op.getDst().getType().cast(); + auto dstMemrefType = cast(op.getDst().getType()); Value dstPtr = getStridedElementPtr(loc, dstMemrefType, adaptor.getDst(), adaptor.getDstIndices(), rewriter); auto i8Ty = IntegerType::get(op.getContext(), 8); @@ -599,7 +599,7 @@ if (!getTypeConverter()->useOpaquePointers()) dstPtr = rewriter.create(loc, dstPointerType, dstPtr); - auto srcMemrefType = op.getSrc().getType().cast(); + auto srcMemrefType = cast(op.getSrc().getType()); FailureOr srcAddressSpace = getTypeConverter()->getMemRefAddressSpace(srcMemrefType); if (failed(srcAddressSpace)) diff --git a/mlir/lib/Conversion/OpenACCToLLVM/OpenACCToLLVM.cpp b/mlir/lib/Conversion/OpenACCToLLVM/OpenACCToLLVM.cpp --- a/mlir/lib/Conversion/OpenACCToLLVM/OpenACCToLLVM.cpp +++ b/mlir/lib/Conversion/OpenACCToLLVM/OpenACCToLLVM.cpp @@ -44,14 +44,14 @@ /// Check whether the type is a valid data descriptor. 
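The descriptor check that follows verifies an identified LLVM struct of three fields (base pointer, pointer, i64 size). A compact sketch with template arguments restored, ignoring the struct-name check; the field layout is taken from the surrounding code:

#include "mlir/Dialect/LLVMIR/LLVMTypes.h"

static bool looksLikeDataDescriptor(mlir::Type type) {
  auto structTy = mlir::dyn_cast<mlir::LLVM::LLVMStructType>(type);
  if (!structTy || !structTy.isIdentified() || structTy.getBody().size() != 3)
    return false;
  return mlir::isa<mlir::LLVM::LLVMPointerType>(structTy.getBody()[0]) &&
         mlir::isa<mlir::LLVM::LLVMPointerType>(structTy.getBody()[1]) &&
         structTy.getBody()[2].isInteger(64);
}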
bool DataDescriptor::isValid(Value descriptor) { - if (auto type = descriptor.getType().dyn_cast()) { + if (auto type = dyn_cast(descriptor.getType())) { if (type.isIdentified() && type.getName().startswith(getStructName()) && type.getBody().size() == 3 && - (type.getBody()[kPtrBasePosInDataDescriptor] - .isa() || - type.getBody()[kPtrBasePosInDataDescriptor] - .isa()) && - type.getBody()[kPtrPosInDataDescriptor].isa() && + (isa( + type.getBody()[kPtrBasePosInDataDescriptor]) || + isa( + type.getBody()[kPtrBasePosInDataDescriptor])) && + isa(type.getBody()[kPtrPosInDataDescriptor]) && type.getBody()[kSizePosInDataDescriptor].isInteger(64)) return true; } @@ -104,7 +104,7 @@ // Traverse operands that were converted to MemRefDescriptors. if (auto memRefType = - originalDataOperand.getType().dyn_cast()) { + dyn_cast(originalDataOperand.getType())) { Type structType = converter->convertType(memRefType); Value memRefDescriptor = builder .create( @@ -127,7 +127,7 @@ descr.setPointer(builder, loc, dataPtr); descr.setSize(builder, loc, sizeBytes); convertedOperands.push_back(descr); - } else if (originalDataOperand.getType().isa()) { + } else if (isa(originalDataOperand.getType())) { convertedOperands.push_back(originalDataOperand); } else { // Type not supported. @@ -189,7 +189,7 @@ auto allDataOperandsAreConverted = [](ValueRange operands) { for (Value operand : operands) { if (!DataDescriptor::isValid(operand) && - !operand.getType().isa()) + !isa(operand.getType())) return false; } return true; diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -70,7 +70,7 @@ Value originalVariableOperand = curOp.getVariableOperand(idx); if (!originalVariableOperand) return failure(); - if (originalVariableOperand.getType().isa()) { + if (isa(originalVariableOperand.getType())) { // TODO: Support memref type in variable operands return rewriter.notifyMatchFailure(curOp, "memref is not supported yet"); @@ -101,7 +101,7 @@ Value originalVariableOperand = curOp.getVariableOperand(idx); if (!originalVariableOperand) return failure(); - if (originalVariableOperand.getType().isa()) { + if (isa(originalVariableOperand.getType())) { // TODO: Support memref type in variable operands return rewriter.notifyMatchFailure(curOp, "memref is not supported yet"); @@ -143,7 +143,7 @@ LogicalResult matchAndRewrite(omp::ReductionOp curOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (curOp.getAccumulator().getType().isa()) { + if (isa(curOp.getAccumulator().getType())) { // TODO: Support memref type in variable operands return rewriter.notifyMatchFailure(curOp, "memref is not supported yet"); } diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp --- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp @@ -219,7 +219,7 @@ // If this value corresponds to an operation, record that we are going to use // its location as part of a fused location. - bool isOperationValue = val && val.getType().isa(); + bool isOperationValue = val && isa(val.getType()); if (isOperationValue) locOps.insert(val); @@ -280,7 +280,7 @@ // The first operation retrieves the representative value of a range. 
   // This applies only when the parent is a range of values and we were
   // requested to use a representative value (e.g., upward traversal).
-  if (parentVal.getType().isa<pdl::RangeType>() &&
+  if (isa<pdl::RangeType>(parentVal.getType()) &&
       usersPos->useRepresentative())
     value = builder.create<pdl_interp::ExtractOp>(loc, parentVal, 0);
   else
@@ -327,7 +327,7 @@
     break;
   }
   case Predicates::TypePos: {
-    if (parentVal.getType().isa<pdl::AttributeType>())
+    if (isa<pdl::AttributeType>(parentVal.getType()))
       value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal);
     else
      value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal);
@@ -357,11 +357,11 @@
   case Predicates::TypeLiteralPos: {
     auto *typePos = cast<TypeLiteralPosition>(pos);
     Attribute rawTypeAttr = typePos->getValue();
-    if (TypeAttr typeAttr = rawTypeAttr.dyn_cast<TypeAttr>())
+    if (TypeAttr typeAttr = dyn_cast<TypeAttr>(rawTypeAttr))
       value = builder.create<pdl_interp::CreateTypeOp>(loc, typeAttr);
     else
       value = builder.create<pdl_interp::CreateTypesOp>(
-          loc, rawTypeAttr.cast<ArrayAttr>());
+          loc, cast<ArrayAttr>(rawTypeAttr));
     break;
   }
   default:
@@ -410,7 +410,7 @@
   }
   case Predicates::TypeQuestion: {
     auto *ans = cast<TypeAnswer>(answer);
-    if (val.getType().isa<pdl::RangeType>())
+    if (isa<pdl::RangeType>(val.getType()))
       builder.create<pdl_interp::CheckTypesOp>(
           loc, val, ans->getValue().cast<ArrayAttr>(), success, failure);
     else
@@ -554,7 +554,7 @@
         OperationNameAnswer>(val, defaultDest, builder, children);
   case Predicates::TypeQuestion:
-    if (val.getType().isa<pdl::RangeType>()) {
+    if (isa<pdl::RangeType>(val.getType())) {
       return createSwitchOp<pdl_interp::SwitchTypesOp, ArrayAttr>(
           val, defaultDest, builder, children);
     }
@@ -745,7 +745,7 @@
   // Handle the case where there is a single range representing all of the
   // result types.
   OperandRange resultTys = operationOp.getTypeValues();
-  if (resultTys.size() == 1 && resultTys[0].getType().isa<pdl::RangeType>()) {
+  if (resultTys.size() == 1 && isa<pdl::RangeType>(resultTys[0].getType())) {
    Value &type = rewriteValues[resultTys[0]];
    if (!type) {
      auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp);
@@ -762,7 +762,7 @@
     Value &type = rewriteValues[it.value()];
     if (type)
       continue;
-    bool isVariadic = it.value().getType().isa<pdl::RangeType>();
+    bool isVariadic = isa<pdl::RangeType>(it.value().getType());
     seenVariableLength |= isVariadic;

     // After a variable length result has been seen, we need to use result
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
--- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
@@ -41,14 +41,14 @@
 /// Returns the number of non-range elements within `values`.
 static unsigned getNumNonRangeValues(ValueRange values) {
   return llvm::count_if(values.getTypes(),
-                        [](Type type) { return !type.isa<pdl::RangeType>(); });
+                        [](Type type) { return !isa<pdl::RangeType>(type); });
 }

 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
                               Value val, PredicateBuilder &builder,
                               DenseMap<Value, Position *> &inputs,
                               AttributePosition *pos) {
-  assert(val.getType().isa<pdl::AttributeType>() && "expected attribute type");
+  assert(isa<pdl::AttributeType>(val.getType()) && "expected attribute type");
   pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp());
   predList.emplace_back(pos, builder.getIsNotNull());
@@ -65,7 +65,7 @@
                               DenseMap<Value, Position *> &inputs,
                               Position *pos) {
   Type valueType = val.getType();
-  bool isVariadic = valueType.isa<pdl::RangeType>();
+  bool isVariadic = isa<pdl::RangeType>(valueType);

   // If this is a typed operand, add a type constraint.
   TypeSwitch<Operation *>(val.getDefiningOp())
@@ -111,7 +111,7 @@
                               PredicateBuilder &builder,
                               DenseMap<Value, Position *> &inputs,
                               OperationPosition *pos,
                               std::optional<unsigned> ignoreOperand = std::nullopt) {
-  assert(val.getType().isa<pdl::OperationType>() && "expected operation");
+  assert(isa<pdl::OperationType>(val.getType()) && "expected operation");
   pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp());
   OperationPosition *opPos = cast<OperationPosition>(pos);
@@ -148,7 +148,7 @@
        llvm::zip(op.getAttributeValueNames(), op.getAttributeValues())) {
     getTreePredicates(
         predList, attr, builder, inputs,
-        builder.getAttribute(opPos, attrName.cast<StringAttr>().getValue()));
+        builder.getAttribute(opPos, cast<StringAttr>(attrName).getValue()));
   }

   // Process the operands and results of the operation. For all values up to
@@ -157,7 +157,7 @@
   // concrete indices until runtime. If there is only one variadic operand
   // group, we treat it as all of the operands/results of the operation.
   /// Operands.
-  if (operands.size() == 1 && operands[0].getType().isa<pdl::RangeType>()) {
+  if (operands.size() == 1 && isa<pdl::RangeType>(operands[0].getType())) {
     // Ignore the operands if we are performing an upward traversal (in that
     // case, they have already been visited).
     if (opPos->isRoot() || opPos->isOperandDefiningOp())
   } else {
     bool foundVariableLength = false;
     for (const auto &operandIt : llvm::enumerate(operands)) {
-      bool isVariadic = operandIt.value().getType().isa<pdl::RangeType>();
+      bool isVariadic = isa<pdl::RangeType>(operandIt.value().getType());
       foundVariableLength |= isVariadic;

       // Ignore the specified operand, usually because this position was
@@ -182,7 +182,7 @@
     }
   }
   /// Results.
-  if (types.size() == 1 && types[0].getType().isa<pdl::RangeType>()) {
+  if (types.size() == 1 && isa<pdl::RangeType>(types[0].getType())) {
     getTreePredicates(predList, types.front(), builder, inputs,
                       builder.getType(builder.getAllResults(opPos)));
     return;
@@ -190,7 +190,7 @@
   bool foundVariableLength = false;
   for (auto [idx, typeValue] : llvm::enumerate(types)) {
-    bool isVariadic = typeValue.getType().isa<pdl::RangeType>();
+    bool isVariadic = isa<pdl::RangeType>(typeValue.getType());
     foundVariableLength |= isVariadic;

     auto *resultPos = foundVariableLength
@@ -301,7 +301,7 @@
   // Ensure that the result isn't null if the result has an index.
   auto *parentPos = cast<OperationPosition>(inputs.lookup(op.getParent()));
-  bool isVariadic = op.getType().isa<pdl::RangeType>();
+  bool isVariadic = isa<pdl::RangeType>(op.getType());
   std::optional<unsigned> index = op.getIndex();
   resultPos = builder.getResultGroup(parentPos, index, isVariadic);
   if (index)
@@ -458,7 +458,7 @@
   // Special case when we pass all the operands in one range.
   // For those, the index is empty.
   if (operands.size() == 1 &&
-      operands[0].getType().isa<pdl::RangeType>()) {
+      isa<pdl::RangeType>(operands[0].getType())) {
     toVisit.emplace(operands[0], entry.value, std::nullopt,
                     entry.depth + 1);
     return;
@@ -514,7 +514,7 @@
   OperandRange operands = op.getOperandValues();
   assert(index < operands.size() && "operand index out of range");
   for (unsigned i = 0; i <= index; ++i)
-    if (operands[i].getType().isa<pdl::RangeType>())
+    if (isa<pdl::RangeType>(operands[i].getType()))
       return true;
   return false;
 }
@@ -542,7 +542,7 @@
   } else if (useOperandGroup(operationOp, *opIndex.index)) {
     // We are querying an operand group.
     Type type = operationOp.getOperandValues()[*opIndex.index].getType();
-    bool variadic = type.isa<pdl::RangeType>();
+    bool variadic = isa<pdl::RangeType>(type);
     operandPos = builder.getOperandGroup(opPos, opIndex.index, variadic);
   } else {
     // We are querying an individual operand.
@@ -578,7 +578,7 @@
     // Traverse up a group of results.
     auto *opPos = dyn_cast<OperationPosition>(pos);
     assert(opPos && "operations and results must be interleaved");
-    bool isVariadic = value.getType().isa<pdl::RangeType>();
+    bool isVariadic = isa<pdl::RangeType>(value.getType());
     if (opIndex.index)
       pos = builder.getResultGroup(opPos, opIndex.index, isVariadic);
     else
diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
--- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
+++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
@@ -441,7 +441,7 @@
     Value iv, lowerBound, upperBound, step;
     std::tie(mappingAttribute, iv, lowerBound, upperBound, step) = config;
     auto annotation =
-        mappingAttribute.dyn_cast<gpu::ParallelLoopDimMappingAttr>();
+        dyn_cast<gpu::ParallelLoopDimMappingAttr>(mappingAttribute);
     if (!annotation)
       return parallelOp.emitOpError()
              << "expected mapping attribute for lowering to GPU";
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -51,8 +51,7 @@
     Value reducedVal = matchReduction({block.getArguments()[1]},
                                       /*redPos=*/0, combinerOps);
-    if (!reducedVal || !reducedVal.isa<BlockArgument>() ||
-        combinerOps.size() != 1)
+    if (!reducedVal || !isa<BlockArgument>(reducedVal) || combinerOps.size() != 1)
       return false;

     return isa(combinerOps[0]) &&
@@ -155,7 +154,7 @@
 /// Returns an attribute with the minimum (if `min` is set) or the maximum value
 /// (otherwise) for the given float type.
 static Attribute minMaxValueForFloat(Type type, bool min) {
-  auto fltType = type.cast<FloatType>();
+  auto fltType = cast<FloatType>(type);
   return FloatAttr::get(
       type, llvm::APFloat::getLargest(fltSemanticsForType(fltType), min));
 }
@@ -164,7 +163,7 @@
 /// the maximum value (otherwise) for the given integer type, regardless of its
 /// signedness semantics (only the width is considered).
 static Attribute minMaxValueForSignedInt(Type type, bool min) {
-  auto intType = type.cast<IntegerType>();
+  auto intType = cast<IntegerType>(type);
   unsigned bitwidth = intType.getWidth();
   return IntegerAttr::get(type, min ? llvm::APInt::getSignedMinValue(bitwidth)
                                     : llvm::APInt::getSignedMaxValue(bitwidth));
@@ -174,7 +173,7 @@
 /// the maximum value (otherwise) for the given integer type, regardless of its
 /// signedness semantics (only the width is considered).
 static Attribute minMaxValueForUnsignedInt(Type type, bool min) {
-  auto intType = type.cast<IntegerType>();
+  auto intType = cast<IntegerType>(type);
   unsigned bitwidth = intType.getWidth();
   return IntegerAttr::get(type, min ? llvm::APInt::getZero(bitwidth)
                                     : llvm::APInt::getAllOnes(bitwidth));
@@ -388,7 +387,7 @@
     reductionVariables.reserve(parallelOp.getNumReductions());
     for (Value init : parallelOp.getInitVals()) {
       assert((LLVM::isCompatibleType(init.getType()) ||
-              init.getType().isa<LLVM::LLVMPointerType>()) &&
+              isa<LLVM::LLVMPointerType>(init.getType())) &&
              "cannot create a reduction variable if the type is not an LLVM "
              "pointer element");
       Value storage = rewriter.create<LLVM::AllocaOp>(
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
@@ -220,9 +220,8 @@
     auto kernelOperands = adaptor.getOperands().take_back(numKernelOperands);
     for (const auto &operand : llvm::enumerate(kernelOperands)) {
       // Check if the kernel's operand is a ranked memref.
-      auto memRefType = launchOp.getKernelOperand(operand.index())
-                            .getType()
-                            .dyn_cast<MemRefType>();
+      auto memRefType = dyn_cast<MemRefType>(
+          launchOp.getKernelOperand(operand.index()).getType());
       if (!memRefType)
         return failure();
@@ -241,7 +240,7 @@
       // LLVM dialect global variable.
       spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()];
       auto pointeeType =
-          spirvGlobal.getType().cast<spirv::PointerType>().getPointeeType();
+          cast<spirv::PointerType>(spirvGlobal.getType()).getPointeeType();
       auto dstGlobalType = typeConverter->convertType(pointeeType);
       if (!dstGlobalType)
         return failure();
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -37,7 +37,7 @@
 static bool isSignedIntegerOrVector(Type type) {
   if (type.isSignedInteger())
     return true;
-  if (auto vecType = type.dyn_cast<VectorType>())
+  if (auto vecType = dyn_cast<VectorType>(type))
     return vecType.getElementType().isSignedInteger();
   return false;
 }
@@ -46,18 +46,18 @@
 static bool isUnsignedIntegerOrVector(Type type) {
   if (type.isUnsignedInteger())
     return true;
-  if (auto vecType = type.dyn_cast<VectorType>())
+  if (auto vecType = dyn_cast<VectorType>(type))
     return vecType.getElementType().isUnsignedInteger();
   return false;
 }

 /// Returns the bit width of integer, float or vector of float or integer values
 static unsigned getBitWidth(Type type) {
-  assert((type.isIntOrFloat() || type.isa<VectorType>()) &&
+  assert((type.isIntOrFloat() || isa<VectorType>(type)) &&
          "bitwidth is not supported for this type");
   if (type.isIntOrFloat())
     return type.getIntOrFloatBitWidth();
-  auto vecType = type.dyn_cast<VectorType>();
+  auto vecType = dyn_cast<VectorType>(type);
   auto elementType = vecType.getElementType();
   assert(elementType.isIntOrFloat() &&
          "only integers and floats have a bitwidth");
@@ -66,29 +66,29 @@
 /// Returns the bit width of LLVMType integer or vector.
 static unsigned getLLVMTypeBitWidth(Type type) {
-  return (LLVM::isCompatibleVectorType(type) ? LLVM::getVectorElementType(type)
-                                             : type)
-      .cast<IntegerType>()
-      .getWidth();
+  return cast<IntegerType>((LLVM::isCompatibleVectorType(type)
+                                ? LLVM::getVectorElementType(type)
+                                : type))
+      .getWidth();
 }

 /// Creates `IntegerAttribute` with all bits set for given type
 static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
-  if (auto vecType = type.dyn_cast<VectorType>()) {
-    auto integerType = vecType.getElementType().cast<IntegerType>();
+  if (auto vecType = dyn_cast<VectorType>(type)) {
+    auto integerType = cast<IntegerType>(vecType.getElementType());
     return builder.getIntegerAttr(integerType, -1);
   }
-  auto integerType = type.cast<IntegerType>();
+  auto integerType = cast<IntegerType>(type);
   return builder.getIntegerAttr(integerType, -1);
 }

 /// Creates `llvm.mlir.constant` with all bits set for the given type.
 static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
                                       PatternRewriter &rewriter) {
-  if (srcType.isa<VectorType>()) {
+  if (isa<VectorType>(srcType)) {
     return rewriter.create<LLVM::ConstantOp>(
         loc, dstType,
-        SplatElementsAttr::get(srcType.cast<ShapedType>(),
+        SplatElementsAttr::get(cast<ShapedType>(srcType),
                                minusOneIntegerAttribute(srcType, rewriter)));
   }
   return rewriter.create<LLVM::ConstantOp>(
@@ -98,14 +98,14 @@
 /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
 static Value createFPConstant(Location loc, Type srcType, Type dstType,
                               PatternRewriter &rewriter, double value) {
-  if (auto vecType = srcType.dyn_cast<VectorType>()) {
-    auto floatType = vecType.getElementType().cast<FloatType>();
+  if (auto vecType = dyn_cast<VectorType>(srcType)) {
+    auto floatType = cast<FloatType>(vecType.getElementType());
     return rewriter.create<LLVM::ConstantOp>(
         loc, dstType,
         SplatElementsAttr::get(vecType, rewriter.getFloatAttr(floatType, value)));
   }
-  auto floatType = srcType.cast<FloatType>();
+  auto floatType = cast<FloatType>(srcType);
   return rewriter.create<LLVM::ConstantOp>(
       loc, dstType, rewriter.getFloatAttr(floatType, value));
 }
@@ -157,7 +157,7 @@
 static Value optionallyBroadcast(Location loc, Value value, Type srcType,
                                  LLVMTypeConverter &typeConverter,
                                  ConversionPatternRewriter &rewriter) {
-  if (auto vectorType = srcType.dyn_cast<VectorType>()) {
+  if (auto vectorType = dyn_cast<VectorType>(srcType)) {
     unsigned numElements = vectorType.getNumElements();
     return broadcast(loc, value, numElements, typeConverter, rewriter);
   }
@@ -251,7 +251,7 @@
                                  TypeConverter &converter) {
   unsigned stride = type.getArrayStride();
   Type elementType = type.getElementType();
-  auto sizeInBytes = elementType.cast<spirv::SPIRVType>().getSizeInBytes();
+  auto sizeInBytes = cast<spirv::SPIRVType>(elementType).getSizeInBytes();
   if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
     return std::nullopt;
@@ -319,10 +319,9 @@
     indices.insert(indices.begin(), zero);
     rewriter.replaceOpWithNewOp<LLVM::GEPOp>(
         op, dstType,
-        typeConverter.convertType(op.getBasePtr()
-                                      .getType()
-                                      .cast<spirv::PointerType>()
-                                      .getPointeeType()),
+        typeConverter.convertType(
+            cast<spirv::PointerType>(op.getBasePtr().getType())
+                .getPointeeType()),
         adaptor.getBasePtr(), indices);
     return success();
   }
@@ -397,7 +396,7 @@
   matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto srcType = constOp.getType();
-    if (!srcType.isa<VectorType>() && !srcType.isIntOrFloat())
+    if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat())
       return failure();

     auto dstType = typeConverter.convertType(srcType);
@@ -413,15 +412,15 @@
         isUnsignedIntegerOrVector(srcType)) {
       auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));

-      if (srcType.isa<VectorType>()) {
-        auto dstElementsAttr = constOp.getValue().cast<DenseIntElementsAttr>();
+      if (isa<VectorType>(srcType)) {
+        auto dstElementsAttr = cast<DenseIntElementsAttr>(constOp.getValue());
         rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
             constOp, dstType,
             dstElementsAttr.mapValues(
                 signlessType, [&](const APInt &value) { return value; }));
         return success();
       }
-      auto srcAttr = constOp.getValue().cast<IntegerAttr>();
+      auto srcAttr = cast<IntegerAttr>(constOp.getValue());
       auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
       rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
       return success();
@@ -454,17 +453,17 @@
     // Create a constant that holds the size of the `Base`.
     IntegerType integerType;
-    if (auto vecType = srcType.dyn_cast<VectorType>())
-      integerType = vecType.getElementType().cast<IntegerType>();
+    if (auto vecType = dyn_cast<VectorType>(srcType))
+      integerType = cast<IntegerType>(vecType.getElementType());
     else
-      integerType = srcType.cast<IntegerType>();
+      integerType = cast<IntegerType>(srcType);

     auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType));
     Value size =
-        srcType.isa<VectorType>()
+        isa<VectorType>(srcType)
             ? rewriter.create<LLVM::ConstantOp>(
                   loc, dstType,
-                  SplatElementsAttr::get(srcType.cast<ShapedType>(), baseSize))
+                  SplatElementsAttr::get(cast<ShapedType>(srcType), baseSize))
             : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize);

     // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
@@ -573,9 +572,9 @@
       return failure();

     Type containerType = op.getComposite().getType();
-    if (containerType.isa<VectorType>()) {
+    if (isa<VectorType>(containerType)) {
       Location loc = op.getLoc();
-      IntegerAttr value = op.getIndices()[0].cast<IntegerAttr>();
+      IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
       Value index = createI32ConstantOf(loc, rewriter, value.getInt());
       rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
           op, dstType, adaptor.getComposite(), index);
@@ -605,9 +604,9 @@
       return failure();

     Type containerType = op.getComposite().getType();
-    if (containerType.isa<VectorType>()) {
+    if (isa<VectorType>(containerType)) {
       Location loc = op.getLoc();
-      IntegerAttr value = op.getIndices()[0].cast<IntegerAttr>();
+      IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
       Value index = createI32ConstantOf(loc, rewriter, value.getInt());
       rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
           op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
@@ -732,7 +731,7 @@
     if (op.getInitializer())
       return failure();

-    auto srcType = op.getType().cast<spirv::PointerType>();
+    auto srcType = cast<spirv::PointerType>(op.getType());
     auto dstType = typeConverter.convertType(srcType.getPointeeType());
     if (!dstType)
       return failure();
@@ -946,12 +945,12 @@
     Location loc = notOp.getLoc();
     IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
-    auto mask = srcType.template isa<VectorType>()
-                    ? rewriter.create<LLVM::ConstantOp>(
-                          loc, dstType,
-                          SplatElementsAttr::get(
-                              srcType.template cast<VectorType>(), minusOne))
-                    : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
+    auto mask =
+        isa<VectorType>(srcType)
+            ? rewriter.create<LLVM::ConstantOp>(
+                  loc, dstType,
+                  SplatElementsAttr::get(cast<VectorType>(srcType), minusOne))
+            : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
     rewriter.template replaceOpWithNewOp(notOp, dstType,
                                          notOp.getOperand(), mask);
     return success();
@@ -1262,9 +1261,9 @@
                   ConversionPatternRewriter &rewriter) const override {
     auto srcType = varOp.getType();
     // Initialization is supported for scalars and vectors only.
-    auto pointerTo = srcType.cast<spirv::PointerType>().getPointeeType();
+    auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType();
     auto init = varOp.getInitializer();
-    if (init && !pointerTo.isIntOrFloat() && !pointerTo.isa<VectorType>())
+    if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo))
       return failure();

     auto dstType = typeConverter.convertType(srcType);
@@ -1303,7 +1302,7 @@
       return failure();

     if (typeConverter.useOpaquePointers() &&
-        dstType.isa<LLVM::LLVMPointerType>()) {
+        isa<LLVM::LLVMPointerType>(dstType)) {
       rewriter.replaceOp(bitcastOp, adaptor.getOperand());
       return success();
     }
@@ -1416,8 +1415,8 @@
     auto components = adaptor.getComponents();
     auto vector1 = adaptor.getVector1();
     auto vector2 = adaptor.getVector2();
-    int vector1Size = vector1.getType().cast<VectorType>().getNumElements();
-    int vector2Size = vector2.getType().cast<VectorType>().getNumElements();
+    int vector1Size = cast<VectorType>(vector1.getType()).getNumElements();
+    int vector2Size = cast<VectorType>(vector2.getType()).getNumElements();
     if (vector1Size == vector2Size) {
       rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
           op, vector1, vector2,
@@ -1426,16 +1425,16 @@
     auto dstType = typeConverter.convertType(op.getType());
-    auto scalarType = dstType.cast<VectorType>().getElementType();
+    auto scalarType = cast<VectorType>(dstType).getElementType();
     auto componentsArray = components.getValue();
     auto *context = rewriter.getContext();
     auto llvmI32Type = IntegerType::get(context, 32);
     Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType);
     for (unsigned i = 0; i < componentsArray.size(); i++) {
-      if (!componentsArray[i].isa<IntegerAttr>())
+      if (!isa<IntegerAttr>(componentsArray[i]))
         return op.emitError("unable to support non-constant component");

-      int indexVal = componentsArray[i].cast<IntegerAttr>().getInt();
+      int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt();
       if (indexVal == -1)
         continue;
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -59,7 +59,7 @@
   matchAndRewrite(SrcOpTy op, typename SrcOpTy::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // For now, only error-free types are supported by this lowering.
-    if (op.getType().template isa<SizeType>())
+    if (isa<SizeType>(op.getType()))
       return failure();

     rewriter.replaceOpWithNewOp<DstOpTy>(op, adaptor.getLhs(),
                                          adaptor.getRhs());
@@ -127,7 +127,7 @@
                                  ConversionPatternRewriter &rewriter) const {
   // For now, this lowering is only defined on `tensor<?xindex>` operands, not
   // on shapes.
-  if (op.getType().isa<ShapeType>())
+  if (isa<ShapeType>(op.getType()))
     return failure();

   auto loc = op.getLoc();
@@ -189,7 +189,7 @@
   // For now, this lowering supports only extent tensors, not `shape.shape`
   // types.
-  if (op.getType().isa<ShapeType>())
+  if (isa<ShapeType>(op.getType()))
     return failure();

   auto loc = op.getLoc();
@@ -242,7 +242,7 @@
   // For now, this lowering is only defined on `tensor<?xindex>` operands, not
   // on shapes.
   if (!llvm::all_of(op.getShapes(),
-                    [](Value v) { return !v.getType().isa<ShapeType>(); }))
+                    [](Value v) { return !isa<ShapeType>(v.getType()); }))
     return failure();

   auto loc = op.getLoc();
@@ -363,13 +363,13 @@
     GetExtentOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   // For now, only error-free types are supported by this lowering.
-  if (op.getType().isa<SizeType>())
+  if (isa<SizeType>(op.getType()))
     return failure();

   // Derive shape extent directly from shape origin if possible. This
   // circumvents the necessity to materialize the shape in memory.
   if (auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>()) {
-    if (shapeOfOp.getArg().getType().isa<ShapedType>()) {
+    if (isa<ShapedType>(shapeOfOp.getArg().getType())) {
       rewriter.replaceOpWithNewOp<tensor::DimOp>(op, shapeOfOp.getArg(),
                                                  adaptor.getDim());
       return success();
@@ -397,7 +397,7 @@
 RankOpConverter::matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
   // For now, this lowering supports only error-free types.
-  if (op.getType().isa<SizeType>())
+  if (isa<SizeType>(op.getType()))
     return failure();

   rewriter.replaceOpWithNewOp<tensor::DimOp>(op, adaptor.getShape(), 0);
@@ -420,7 +420,7 @@
 ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
   // For now, this lowering is only defined on `tensor<?xindex>` operands.
-  if (op.getShape().getType().isa<ShapeType>())
+  if (isa<ShapeType>(op.getShape().getType()))
     return failure();

   auto loc = op.getLoc();
@@ -499,7 +499,7 @@
 ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
                                     ConversionPatternRewriter &rewriter) const {
   if (!llvm::all_of(op.getShapes(),
-                    [](Value v) { return !v.getType().isa<ShapeType>(); }))
+                    [](Value v) { return !isa<ShapeType>(v.getType()); }))
     return failure();

   Type i1Ty = rewriter.getI1Type();
@@ -570,18 +570,18 @@
                                  ConversionPatternRewriter &rewriter) const {
   // For now, only error-free types are supported by this lowering.
-  if (op.getType().isa<ShapeType>())
+  if (isa<ShapeType>(op.getType()))
     return failure();

   // For ranked tensor arguments, lower to `tensor.from_elements`.
   auto loc = op.getLoc();
   Value tensor = adaptor.getArg();
   Type tensorTy = tensor.getType();
-  if (tensorTy.isa<RankedTensorType>()) {
+  if (isa<RankedTensorType>(tensorTy)) {

     // Build values for individual extents.
     SmallVector<Value> extentValues;
-    RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
+    RankedTensorType rankedTensorTy = cast<RankedTensorType>(tensorTy);
     int64_t rank = rankedTensorTy.getRank();
     for (int64_t i = 0; i < rank; i++) {
       if (rankedTensorTy.isDynamicDim(i)) {
@@ -634,7 +634,7 @@
   // Error conditions are not implemented, only lower if all operands and
   // results are extent tensors.
   if (llvm::any_of(ValueRange{op.getOperand(), op.getHead(), op.getTail()},
-                   [](Value v) { return v.getType().isa<ShapeType>(); }))
+                   [](Value v) { return isa<ShapeType>(v.getType()); }))
     return failure();

   ImplicitLocOpBuilder b(op.getLoc(), rewriter);
@@ -667,7 +667,7 @@
   LogicalResult
   matchAndRewrite(ToExtentTensorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (!adaptor.getInput().getType().isa<RankedTensorType>())
+    if (!isa<RankedTensorType>(adaptor.getInput().getType()))
       return rewriter.notifyMatchFailure(op, "input needs to be a tensor");

     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp
--- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp
+++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp
@@ -44,7 +44,7 @@
   LogicalResult
   matchAndRewrite(tensor::ExtractOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto tensorType = extractOp.getTensor().getType().cast<RankedTensorType>();
+    auto tensorType = cast<RankedTensorType>(extractOp.getTensor().getType());
     if (!tensorType.hasStaticShape())
       return rewriter.notifyMatchFailure(extractOp, "non-static tensor");
diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
--- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
+++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
@@ -34,14 +34,14 @@
 };

 Type matchContainerType(Type element, Type container) {
-  if (auto shapedTy = container.dyn_cast<ShapedType>())
+  if (auto shapedTy = dyn_cast<ShapedType>(container))
     return shapedTy.clone(element);
   return element;
 }

 TypedAttr getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) {
-  if (auto shapedTy = type.dyn_cast<ShapedType>()) {
+  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
     Type eTy = shapedTy.getElementType();
     APInt valueInt(eTy.getIntOrFloatBitWidth(), value);
     return DenseIntElementsAttr::get(shapedTy, valueInt);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -36,7 +36,7 @@
 createConstFromIntAttribute(Operation *op, const std::string &attrName,
                             Type requiredAttrType, OpBuilder &rewriter) {
   auto castedN = static_cast<T>(
-      op->getAttr(attrName).cast<IntegerAttr>().getValue().getSExtValue());
+      cast<IntegerAttr>(op->getAttr(attrName)).getValue().getSExtValue());
   return rewriter.create<arith::ConstantOp>(
       op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
 }
@@ -47,13 +47,13 @@
                                  PatternRewriter &rewriter) {
   Location loc = op->getLoc();
   auto elementTy =
-      op->getOperand(0).getType().cast<ShapedType>().getElementType();
+      cast<ShapedType>(op->getOperand(0).getType()).getElementType();

   // tosa::AbsOp
-  if (isa<tosa::AbsOp>(op) && elementTy.isa<FloatType>())
+  if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
     return rewriter.create<math::AbsFOp>(loc, resultTypes, args);

-  if (isa<tosa::AbsOp>(op) && elementTy.isa<IntegerType>()) {
+  if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
     auto zero = rewriter.create<arith::ConstantOp>(
         loc, rewriter.getZeroAttr(elementTy));
     auto cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
                                               args[0], zero);
@@ -63,21 +63,21 @@
   }

   // tosa::AddOp
-  if (isa<tosa::AddOp>(op) && elementTy.isa<FloatType>())
+  if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
     return rewriter.create<arith::AddFOp>(loc, resultTypes, args);

-  if (isa<tosa::AddOp>(op) && elementTy.isa<IntegerType>())
+  if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
     return rewriter.create<arith::AddIOp>(loc, resultTypes, args);

   // tosa::SubOp
-  if (isa<tosa::SubOp>(op) && elementTy.isa<FloatType>())
+  if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
     return rewriter.create<arith::SubFOp>(loc, resultTypes, args);

-  if (isa<tosa::SubOp>(op) && elementTy.isa<IntegerType>())
+  if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
     return rewriter.create<arith::SubIOp>(loc, resultTypes,
                                           args);

   // tosa::MulOp
-  if (isa<tosa::MulOp>(op) && elementTy.isa<FloatType>()) {
+  if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy)) {
     if (dyn_cast<tosa::MulOp>(op).getShift() != 0) {
       (void)rewriter.notifyMatchFailure(op,
                                         "Cannot have shift value for float");
@@ -87,21 +87,21 @@
   }

   // tosa::DivOp
-  if (isa<tosa::DivOp>(op) && elementTy.isa<IntegerType>())
+  if (isa<tosa::DivOp>(op) && isa<IntegerType>(elementTy))
     return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);

   // tosa::ReciprocalOp
-  if (isa<tosa::ReciprocalOp>(op) && elementTy.isa<FloatType>()) {
+  if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
     auto one =
         rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
     return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
   }

-  if (isa<tosa::MulOp>(op) && elementTy.isa<IntegerType>()) {
+  if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
     Value a = args[0];
     Value b = args[1];
     auto shift =
-        op->getAttr("shift").cast<IntegerAttr>().getValue().getSExtValue();
+        cast<IntegerAttr>(op->getAttr("shift")).getValue().getSExtValue();
     if (shift > 0) {
       auto shiftConst =
           rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
@@ -134,17 +134,17 @@
   }

   // tosa::NegateOp
-  if (isa<tosa::NegateOp>(op) && elementTy.isa<FloatType>())
+  if (isa<tosa::NegateOp>(op) && isa<FloatType>(elementTy))
     return rewriter.create<arith::NegFOp>(loc, resultTypes, args);

-  if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
+  if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy) &&
       !cast<tosa::NegateOp>(op).getQuantizationInfo()) {
     auto constant =
         rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
     return rewriter.create<arith::SubIOp>(loc, resultTypes, constant, args[0]);
   }

-  if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
+  if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy) &&
       cast<tosa::NegateOp>(op).getQuantizationInfo()) {
     auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
     int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
@@ -190,15 +190,15 @@
   }

   // tosa::BitwiseAndOp
-  if (isa<tosa::BitwiseAndOp>(op) && elementTy.isa<IntegerType>())
+  if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
     return rewriter.create<arith::AndIOp>(loc, resultTypes, args);

   // tosa::BitwiseOrOp
-  if (isa<tosa::BitwiseOrOp>(op) && elementTy.isa<IntegerType>())
+  if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
     return rewriter.create<arith::OrIOp>(loc, resultTypes, args);

   // tosa::BitwiseNotOp
-  if (isa<tosa::BitwiseNotOp>(op) && elementTy.isa<IntegerType>()) {
+  if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
     auto allOnesAttr = rewriter.getIntegerAttr(
         elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
     auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr);
@@ -206,21 +206,21 @@
   }

   // tosa::BitwiseXOrOp
-  if (isa<tosa::BitwiseXorOp>(op) && elementTy.isa<IntegerType>())
+  if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
     return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);

   // tosa::LogicalLeftShiftOp
-  if (isa<tosa::LogicalLeftShiftOp>(op) && elementTy.isa<IntegerType>())
+  if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
     return rewriter.create<arith::ShLIOp>(loc, resultTypes, args);

   // tosa::LogicalRightShiftOp
-  if (isa<tosa::LogicalRightShiftOp>(op) && elementTy.isa<IntegerType>())
+  if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
     return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args);

   // tosa::ArithmeticRightShiftOp
-  if (isa<tosa::ArithmeticRightShiftOp>(op) && elementTy.isa<IntegerType>()) {
+  if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
     auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args);
-    auto round = op->getAttr("round").cast<BoolAttr>().getValue();
+    auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
     if (!round) {
       return result;
     }
@@ -256,7 +256,7 @@
   }

   // tosa::ClzOp
-  if (isa<tosa::ClzOp>(op) && elementTy.isa<IntegerType>()) {
+  if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
     return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
   }
@@ -280,27 +280,27 @@
     return rewriter.create(loc, resultTypes, args);

   // tosa::PowOp
-  if (isa<tosa::PowOp>(op) && elementTy.isa<FloatType>())
+  if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
     return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);

   // tosa::RsqrtOp
-  if (isa<tosa::RsqrtOp>(op) && elementTy.isa<FloatType>())
+  if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
     return rewriter.create<math::RsqrtOp>(loc, resultTypes, args);

   // tosa::LogOp
-  if (isa<tosa::LogOp>(op) && elementTy.isa<FloatType>())
+  if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
     return rewriter.create<math::LogOp>(loc, resultTypes, args);

   // tosa::ExpOp
-  if (isa<tosa::ExpOp>(op) && elementTy.isa<FloatType>())
+  if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
     return rewriter.create<math::ExpOp>(loc, resultTypes, args);
   // tosa::TanhOp
-  if (isa<tosa::TanhOp>(op) && elementTy.isa<FloatType>())
+  if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
     return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);

   // tosa::GreaterOp
-  if (isa<tosa::GreaterOp>(op) && elementTy.isa<FloatType>())
+  if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
     return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
                                           args[0], args[1]);
@@ -309,7 +309,7 @@
                                           args[0], args[1]);

   // tosa::GreaterEqualOp
-  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isa<FloatType>())
+  if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
     return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
                                           args[0], args[1]);
@@ -318,7 +318,7 @@
                                           args[0], args[1]);

   // tosa::EqualOp
-  if (isa<tosa::EqualOp>(op) && elementTy.isa<FloatType>())
+  if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
     return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
                                           args[0], args[1]);
@@ -328,13 +328,13 @@
   // tosa::SelectOp
   if (isa<tosa::SelectOp>(op)) {
-    elementTy = op->getOperand(1).getType().cast<ShapedType>().getElementType();
-    if (elementTy.isa<FloatType>() || elementTy.isa<IntegerType>())
+    elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
+    if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
       return rewriter.create<arith::SelectOp>(loc, args[0], args[1], args[2]);
   }

   // tosa::MaximumOp
-  if (isa<tosa::MaximumOp>(op) && elementTy.isa<FloatType>()) {
+  if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
     return rewriter.create<arith::MaxFOp>(loc, args[0], args[1]);
   }
@@ -345,7 +345,7 @@
   }

   // tosa::MinimumOp
-  if (isa<tosa::MinimumOp>(op) && elementTy.isa<FloatType>()) {
+  if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
     return rewriter.create<arith::MinFOp>(loc, args[0], args[1]);
   }
@@ -356,21 +356,21 @@
   }

   // tosa::CeilOp
-  if (isa<tosa::CeilOp>(op) && elementTy.isa<FloatType>())
+  if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
     return rewriter.create<math::CeilOp>(loc, resultTypes, args);

   // tosa::FloorOp
-  if (isa<tosa::FloorOp>(op) && elementTy.isa<FloatType>())
+  if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
     return rewriter.create<math::FloorOp>(loc, resultTypes, args);

   // tosa::ClampOp
-  if (isa<tosa::ClampOp>(op) && elementTy.isa<FloatType>()) {
+  if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
     bool losesInfo = false;
-    APFloat minApf = op->getAttr("min_fp").cast<FloatAttr>().getValue();
-    APFloat maxApf = op->getAttr("max_fp").cast<FloatAttr>().getValue();
-    minApf.convert(elementTy.cast<FloatType>().getFloatSemantics(),
+    APFloat minApf = cast<FloatAttr>(op->getAttr("min_fp")).getValue();
+    APFloat maxApf = cast<FloatAttr>(op->getAttr("max_fp")).getValue();
+    minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                    APFloat::rmNearestTiesToEven, &losesInfo);
-    maxApf.convert(elementTy.cast<FloatType>().getFloatSemantics(),
+    maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                    APFloat::rmNearestTiesToEven, &losesInfo);
     auto min = rewriter.create<arith::ConstantOp>(
         loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
@@ -379,12 +379,12 @@
     return clampFloatHelper(loc, args[0], min, max, rewriter);
   }

-  if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
-    auto intTy = elementTy.cast<IntegerType>();
+  if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
+    auto intTy = cast<IntegerType>(elementTy);
     int32_t min = static_cast<int32_t>(
-        op->getAttr("min_int").cast<IntegerAttr>().getValue().getSExtValue());
+        cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue());
     int32_t max = static_cast<int32_t>(
-        op->getAttr("max_int").cast<IntegerAttr>().getValue().getSExtValue());
+        cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue());

     if (intTy.isUnsignedInteger()) {
       min = std::max(min, 0);
@@ -408,7 +408,7 @@
   }

   // tosa::SigmoidOp
-  if (isa<tosa::SigmoidOp>(op) && elementTy.isa<FloatType>()) {
+  if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
     auto one =
         rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
     auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
@@ -427,11 +427,11 @@
   if (srcTy == dstTy)
     return args.front();

-  if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && bitExtend)
+  if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
     return rewriter.create<arith::ExtFOp>(loc, resultTypes, args, std::nullopt);

-  if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && !bitExtend)
+  if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
     return rewriter.create<arith::TruncFOp>(loc, resultTypes, args,
                                             std::nullopt);

@@ -440,13 +440,13 @@
     return rewriter.create(loc, resultTypes, args, std::nullopt);

-  if (srcTy.isInteger(1) && dstTy.isa<IntegerType>() && bitExtend)
+  if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
     return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args,
                                            std::nullopt);

   // Unsigned integers need an unrealized cast so that they can be passed
   // to UIToFP.
-  if (srcTy.isUnsignedInteger() && dstTy.isa<FloatType>()) {
+  if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
     auto unrealizedCast =
         rewriter
             .create<UnrealizedConversionCastOp>(
@@ -463,7 +463,7 @@
             std::nullopt);

   // Casting to boolean, floats need to only be checked as not-equal to zero.
-  if (srcTy.isa<FloatType>() && dstTy.isInteger(1)) {
+  if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
     Value zero = rewriter.create<arith::ConstantOp>(
         loc, rewriter.getFloatAttr(srcTy, 0.0));
     return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
@@ -490,18 +490,18 @@
   // Casting to boolean, integers need to only be checked as not-equal to
   // zero.
-  if (srcTy.isa<IntegerType>() && dstTy.isInteger(1)) {
+  if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
     Value zero = rewriter.create<arith::ConstantIntOp>(
         loc, 0, srcTy.getIntOrFloatBitWidth());
     return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
                                           args.front(), zero);
   }

-  if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && bitExtend)
+  if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
     return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args,
                                            std::nullopt);

-  if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && !bitExtend) {
+  if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
     return rewriter.create<arith::TruncIOp>(loc, dstTy, args[0]);
   }
 }
@@ -520,7 +520,7 @@
          "All TOSA elementwise ops should only return a single result.");

   auto results = operation->getResults();
-  auto resultTy = operation->getResult(0).getType().dyn_cast<ShapedType>();
+  auto resultTy = dyn_cast<ShapedType>(operation->getResult(0).getType());

   if (!resultTy)
     return rewriter.notifyMatchFailure(operation,
@@ -538,10 +538,10 @@
   SmallVector<Value> emptyTensors;

   SmallVector<Value> dynDims;
-  dynDims.resize(results.front().getType().cast<ShapedType>().getRank());
+  dynDims.resize(cast<ShapedType>(results.front().getType()).getRank());

   for (auto arg : operation->getOperands()) {
-    auto operandTy = arg.getType().cast<ShapedType>();
+    auto operandTy = cast<ShapedType>(arg.getType());
     for (int i = 0; i < operandTy.getRank(); i++) {
       if (operandTy.isDynamicDim(i) && !dynDims[i])
         dynDims[i] = rewriter.create<tensor::DimOp>(loc, arg, i);
@@ -551,7 +551,7 @@
   SmallVector<Value> filteredDims = condenseValues(dynDims);

   for (auto result : results) {
-    auto resultTy = result.getType().template cast<ShapedType>();
+    auto resultTy = cast<ShapedType>(result.getType());
     emptyTensors.push_back(rewriter.create<tensor::EmptyOp>(
         loc, resultTy.getShape(), resultTy.getElementType(), filteredDims));
     opResultTypes.push_back(result.getType());
@@ -566,7 +566,7 @@
   // Input indexing maps may be broadcasted.
   for (Value operand : operation->getOperands()) {
-    ShapedType type = operand.getType().cast<ShapedType>();
+    ShapedType type = cast<ShapedType>(operand.getType());

     if (type.getShape() == resultTy.getShape()) {
       operands.push_back(operand);
@@ -627,33 +627,33 @@
   // attribute type varies depending on the element type required.
 static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
                                                PatternRewriter &rewriter) {
-  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>())
+  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
     return rewriter.getFloatAttr(elementTy, 0.0);

-  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<IntegerType>())
+  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
     return rewriter.getIntegerAttr(elementTy, 0);

-  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<FloatType>())
+  if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy))
     return rewriter.getFloatAttr(elementTy, 1.0);

-  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<IntegerType>())
+  if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy))
     return rewriter.getIntegerAttr(elementTy, 1);

-  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>())
+  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
     return rewriter.getFloatAttr(
         elementTy, APFloat::getLargest(
-                       elementTy.cast<FloatType>().getFloatSemantics(), false));
+                       cast<FloatType>(elementTy).getFloatSemantics(), false));

-  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>())
+  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
     return rewriter.getIntegerAttr(
         elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));

-  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>())
+  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
     return rewriter.getFloatAttr(
         elementTy, APFloat::getLargest(
-                       elementTy.cast<FloatType>().getFloatSemantics(), true));
+                       cast<FloatType>(elementTy).getFloatSemantics(), true));

-  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>())
+  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
     return rewriter.getIntegerAttr(
         elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

@@ -663,12 +663,12 @@
   if (isa(op) && elementTy.isInteger(1))
     return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));

-  if (isa<tosa::ArgMaxOp>(op) && elementTy.isa<FloatType>())
+  if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
     return rewriter.getFloatAttr(
         elementTy, APFloat::getLargest(
-                       elementTy.cast<FloatType>().getFloatSemantics(), true));
+                       cast<FloatType>(elementTy).getFloatSemantics(), true));

-  if (isa<tosa::ArgMaxOp>(op) && elementTy.isa<IntegerType>())
+  if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
     return rewriter.getIntegerAttr(
         elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

@@ -682,37 +682,37 @@
                                                   Type elementTy,
                                                   PatternRewriter &rewriter) {
   Location loc = op->getLoc();
-  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>()) {
+  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
     return rewriter.create<arith::AddFOp>(loc, args);
   }

-  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<IntegerType>()) {
+  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
     return rewriter.create<arith::AddIOp>(loc, args);
   }

-  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<FloatType>()) {
+  if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy)) {
     return rewriter.create<arith::MulFOp>(loc, args);
   }

-  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<IntegerType>()) {
+  if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy)) {
     return rewriter.create<arith::MulIOp>(loc, args);
   }

-  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>()) {
+  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
     return rewriter.create<arith::MinFOp>(loc, args[0], args[1]);
   }

-  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>()) {
+  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
     auto predicate = rewriter.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::slt, args[0], args[1]);
     return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
   }

-  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>()) {
+  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
     return rewriter.create<arith::MaxFOp>(loc, args[0], args[1]);
   }

-  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>()) {
+  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
     auto predicate = rewriter.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::sgt, args[0], args[1]);
     return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
@@ -733,8 +733,8 @@
 static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
                                                  PatternRewriter &rewriter) {
   auto loc = op->getLoc();
-  auto inputTy = op->getOperand(0).getType().template cast<ShapedType>();
-  auto resultTy = op->getResult(0).getType().template cast<ShapedType>();
+  auto inputTy = cast<ShapedType>(op->getOperand(0).getType());
+  auto resultTy = cast<ShapedType>(op->getResult(0).getType());
   auto elementTy = resultTy.getElementType();
   Value input = op->getOperand(0);
@@ -799,7 +799,7 @@
   SmallVector<ReassociationExprs, 4> reassociationMap;
   uint64_t expandInputRank =
-      linalgOp.getResults()[0].getType().cast<ShapedType>().getRank();
+      cast<ShapedType>(linalgOp.getResults()[0].getType()).getRank();
   reassociationMap.resize(expandInputRank);

   for (uint64_t i = 0; i < expandInputRank; i++) {
@@ -848,14 +848,14 @@
   auto loc = op.getLoc();
   auto input = op->getOperand(0);
-  auto resultTy = op.getType().cast<ShapedType>();
+  auto resultTy = cast<ShapedType>(op.getType());

   SmallVector<Value> dynDims;
-  dynDims.resize(op->getResult(0).getType().cast<ShapedType>().getRank());
+  dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());

   SmallVector<AffineExpr> inputExprs;
   inputExprs.resize(resultTy.getRank());
-  auto operandTy = input.getType().cast<ShapedType>();
+  auto operandTy = cast<ShapedType>(input.getType());
   for (const auto &permutation : llvm::enumerate(perms.getValues<APInt>())) {
     auto index = permutation.index();
     auto value = permutation.value().getZExtValue();
@@ -893,8 +893,8 @@
                                 PatternRewriter &rewriter) const final {
     auto loc = op.getLoc();
     auto input = op.getInput();
-    auto inputTy = op.getInput().getType().cast<ShapedType>();
-    auto outputTy = op.getOutput().getType().cast<ShapedType>();
+    auto inputTy = cast<ShapedType>(op.getInput().getType());
+    auto outputTy = cast<ShapedType>(op.getOutput().getType());
     unsigned rank = inputTy.getRank();

     // This is an illegal configuration. terminate and log an error
@@ -1036,7 +1036,7 @@
           // Saturate to the output size.
           IntegerType outIntType =
-              blockArgs.back().getType().cast<IntegerType>();
+              cast<IntegerType>(blockArgs.back().getType());
           unsigned outBitWidth = outIntType.getWidth();

           int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
@@ -1089,8 +1089,8 @@
     Location loc = op.getLoc();
     ImplicitLocOpBuilder builder(loc, rewriter);
     auto input = op.getInput();
-    auto inputTy = input.getType().cast<ShapedType>();
-    auto resultTy = op.getType().cast<ShapedType>();
+    auto inputTy = cast<ShapedType>(input.getType());
+    auto resultTy = cast<ShapedType>(op.getType());
     const bool isBilinear = op.getMode() == "BILINEAR";

     auto inputH = inputTy.getDimSize(1);
@@ -1186,8 +1186,8 @@
     Location loc = op.getLoc();
     ImplicitLocOpBuilder builder(loc, rewriter);
     auto input = op.getInput();
-    auto inputTy = input.getType().dyn_cast<RankedTensorType>();
-    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
+    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
+    auto resultTy = dyn_cast<RankedTensorType>(op.getType());

     if (!inputTy || !resultTy)
       return rewriter.notifyMatchFailure(op,
@@ -1282,8 +1282,8 @@
     Location loc = op.getLoc();
     ImplicitLocOpBuilder b(loc, rewriter);
     auto input = op.getInput();
-    auto inputTy = input.getType().cast<ShapedType>();
-    auto resultTy = op.getType().cast<ShapedType>();
+    auto inputTy = cast<ShapedType>(input.getType());
+    auto resultTy = cast<ShapedType>(op.getType());
     auto resultETy = resultTy.getElementType();

     auto imageH = inputTy.getShape()[1];
@@ -1573,8 +1573,8 @@
                                 PatternRewriter &rewriter) const final {
     auto loc = op.getLoc();
     Value input = op.getInput();
-    auto inputTy = input.getType().template cast<ShapedType>();
-    auto resultTy = op.getType().template cast<ShapedType>();
+    auto inputTy = cast<ShapedType>(input.getType());
+    auto resultTy = cast<ShapedType>(op.getType());
     auto axis = op.getAxis();

     SmallVector<Value> dynDims;
@@ -1635,9 +1635,9 @@
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
     auto input = op.getInput1();
-    auto inputTy = input.getType().cast<ShapedType>();
+    auto inputTy = cast<ShapedType>(input.getType());
     auto inputShape = inputTy.getShape();
-    auto resultTy = op.getType().cast<ShapedType>();
+    auto resultTy = cast<ShapedType>(op.getType());
     auto elementTy = inputTy.getElementType();
     int64_t rank = inputTy.getRank();
@@ -1710,14 +1710,14 @@
                                 PatternRewriter &rewriter) const final {
     auto loc = argmaxOp.getLoc();
     Value input = argmaxOp.getInput();
-    auto inputTy = input.getType().cast<ShapedType>();
-    auto resultTy = argmaxOp.getOutput().getType().cast<ShapedType>();
+    auto inputTy = cast<ShapedType>(input.getType());
+    auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
     auto inElementTy = inputTy.getElementType();
     auto outElementTy = resultTy.getElementType();
     int axis = argmaxOp.getAxis();
     auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);

-    if (!outElementTy.isa<IntegerType>())
+    if (!isa<IntegerType>(outElementTy))
       return rewriter.notifyMatchFailure(
           argmaxOp,
           "tosa.arg_max to linalg.* requires integer-like result type");
@@ -1792,10 +1792,10 @@
               rewriter.create<arith::ConstantIndexOp>(loc, axis));

         Value predicate;
-        if (inElementTy.isa<FloatType>()) {
+        if (isa<FloatType>(inElementTy)) {
           predicate = rewriter.create<arith::CmpFOp>(
               nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
-        } else if (inElementTy.isa<IntegerType>()) {
+        } else if (isa<IntegerType>(inElementTy)) {
           predicate = rewriter.create<arith::CmpIOp>(
               nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
         } else {
@@ -1830,8 +1830,8 @@
     auto indices = adaptor.getOperands()[1];

     auto valuesTy =
-        op.getValues().getType().dyn_cast_or_null<RankedTensorType>();
-    auto resultTy = op.getType().cast<ShapedType>();
+        dyn_cast_or_null<RankedTensorType>(op.getValues().getType());
+    auto resultTy = cast<ShapedType>(op.getType());

     if (!valuesTy)
       return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
@@ -1904,9 +1904,9 @@
     auto loc = op.getLoc();
     Value input = op.getInput();
     Value table = op.getTable();
-    auto inputTy = input.getType().cast<ShapedType>();
-    auto tableTy = table.getType().cast<ShapedType>();
-    auto resultTy = op.getType().cast<ShapedType>();
+    auto inputTy = cast<ShapedType>(input.getType());
+    auto tableTy = cast<ShapedType>(table.getType());
+    auto resultTy = cast<ShapedType>(op.getType());

     auto inputElementTy = inputTy.getElementType();
     auto tableElementTy = tableTy.getElementType();
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -36,7 +36,7 @@
   if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
     return input;

-  ShapedType inputTy = input.getType().cast<ShapedType>();
+  ShapedType inputTy = cast<ShapedType>(input.getType());
   Type inputETy = inputTy.getElementType();
   auto inputShape = inputTy.getShape();
@@ -67,7 +67,7 @@
 linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias,
                            Value conv, Value result,
                            ArrayRef<AffineMap> indexingMaps) {
-  ShapedType resultTy = conv.getType().cast<ShapedType>();
+  ShapedType resultTy = cast<ShapedType>(conv.getType());
   return rewriter
       .create<linalg::GenericOp>(
           loc, resultTy, ValueRange({bias, conv}), result, indexingMaps,
@@ -125,7 +125,7 @@
     ArrayRef<int64_t> padAttr, ArrayRef<int64_t> strideAttr,
     ArrayRef<int64_t> dilationAttr, ArrayRef<int64_t> inputSizeDims,
     ArrayRef<int64_t> kernelSizeDims, OpBuilder &rewriter) {
-  ShapedType inputTy = input.getType().cast<ShapedType>();
+  ShapedType inputTy = cast<ShapedType>(input.getType());
   Type inputETy = inputTy.getElementType();
   int64_t inputRank = inputTy.getRank();
@@ -187,11 +187,10 @@
     Value weight = op->getOperand(1);
     Value bias = op->getOperand(2);

-    ShapedType inputTy = input.getType().template cast<ShapedType>();
-    ShapedType weightTy = weight.getType().template cast<ShapedType>();
-    ShapedType biasTy = bias.getType().template cast<ShapedType>();
-    ShapedType resultTy =
-        op->getResult(0).getType().template cast<ShapedType>();
+    ShapedType inputTy = cast<ShapedType>(input.getType());
+    ShapedType weightTy = cast<ShapedType>(weight.getType());
+    ShapedType biasTy = cast<ShapedType>(bias.getType());
+    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());

     Type inputETy = inputTy.getElementType();
     Type resultETy = resultTy.getElementType();
@@ -353,18 +352,18 @@
     Value weight = op->getOperand(1);
     Value bias = op->getOperand(2);
-    ShapedType inputTy = input.getType().cast<ShapedType>();
-    ShapedType weightTy = weight.getType().cast<ShapedType>();
-    ShapedType biasTy = bias.getType().cast<ShapedType>();
-    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
+    ShapedType inputTy = cast<ShapedType>(input.getType());
+    ShapedType weightTy = cast<ShapedType>(weight.getType());
+    ShapedType biasTy = cast<ShapedType>(bias.getType());
+    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
     int64_t resultRank = resultTy.getRank();

     Type inputETy = inputTy.getElementType();
     Type resultETy = resultTy.getElementType();

-    auto padAttr = op->getAttr("pad").cast<DenseI64ArrayAttr>();
-    auto strideTosaAttr = op->getAttr("stride").cast<DenseI64ArrayAttr>();
-    auto dilationTosaAttr = op->getAttr("dilation").cast<DenseI64ArrayAttr>();
+    auto padAttr = cast<DenseI64ArrayAttr>(op->getAttr("pad"));
+    auto strideTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("stride"));
+    auto dilationTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("dilation"));

     if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
       return rewriter.notifyMatchFailure(
@@ -382,7 +381,7 @@
     IntegerAttr kZp;
     if (isQuantized) {
       auto quantizationInfo =
-          op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
+          cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
       iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
       kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
     }
@@ -394,7 +393,7 @@
     TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
     if (isQuantized) {
       auto quantizationInfo =
-          op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
+          cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
       int64_t iZp = quantizationInfo.getInputZp();

       int64_t intMin =
@@ -505,14 +504,14 @@
                   ConversionPatternRewriter &rewriter) const final {
     Location loc = op.getLoc();

-    auto outputTy = op.getType().cast<ShapedType>();
+    auto outputTy = cast<ShapedType>(op.getType());
     auto outputElementTy = outputTy.getElementType();

-    auto firstOperandTy = op->getOperand(0).getType().cast<ShapedType>();
-    auto secondOperandTy = op->getOperand(1).getType().cast<ShapedType>();
+    auto firstOperandTy = cast<ShapedType>(op->getOperand(0).getType());
+    auto secondOperandTy = cast<ShapedType>(op->getOperand(1).getType());

     SmallVector<Value> dynDims;
-    dynDims.resize(op->getResult(0).getType().cast<ShapedType>().getRank());
+    dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());

     if (!firstOperandTy.hasRank() || firstOperandTy.isDynamicDim(0)) {
       dynDims[0] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 0);
@@ -564,20 +563,20 @@
   matchAndRewrite(tosa::FullyConnectedOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
     Location loc = op.getLoc();
-    auto outputTy = op.getType().cast<ShapedType>();
+    auto outputTy = cast<ShapedType>(op.getType());
     auto input = op.getInput();
-    auto inputTy = input.getType().cast<ShapedType>();
+    auto inputTy = cast<ShapedType>(input.getType());

     auto bias = op.getBias();

     auto weight = op.getWeight();
-    auto weightTy = weight.getType().cast<ShapedType>();
+    auto weightTy = cast<ShapedType>(weight.getType());
     auto weightShape = weightTy.getShape();

     auto outputETy = outputTy.getElementType();

     SmallVector<Value> dynDims;
-    dynDims.resize(op->getResult(0).getType().cast<ShapedType>().getRank());
+    dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());

     if (!inputTy.hasRank() || inputTy.isDynamicDim(0)) {
       dynDims[0] = rewriter.create<tensor::DimOp>(loc, input, 0);
@@ -676,9 +675,9 @@
                                 PatternRewriter &rewriter) const final {
     Location loc = op.getLoc();
     Value input = op.getInput();
-    ShapedType inputTy = input.getType().cast<ShapedType>();
+    ShapedType inputTy = cast<ShapedType>(input.getType());

-    ShapedType resultTy = op.getType().template cast<ShapedType>();
+    ShapedType resultTy = cast<ShapedType>(op.getType());
     Type resultETy = inputTy.getElementType();

     auto dynamicDimsOr =
@@ -691,11 +690,10 @@
     TypedAttr initialAttr;
     if (resultETy.isF32())
       initialAttr = rewriter.getFloatAttr(
-          resultETy,
-          APFloat::getLargest(resultETy.cast<FloatType>().getFloatSemantics(),
-                              true));
+          resultETy, APFloat::getLargest(
+                         cast<FloatType>(resultETy).getFloatSemantics(), true));

-    if (resultETy.isa<IntegerType>())
+    if (isa<IntegerType>(resultETy))
       initialAttr = rewriter.getIntegerAttr(
           resultETy,
           APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth()));
@@ -747,14 +745,14 @@
                                 PatternRewriter &rewriter) const final {
     Location loc = op.getLoc();
     Value input = op.getInput();
-    ShapedType inputTy = input.getType().cast<ShapedType>();
+    ShapedType inputTy = cast<ShapedType>(input.getType());
     Type inElementTy = inputTy.getElementType();

-    ShapedType resultTy = op.getType().template cast<ShapedType>();
-    Type resultETy = op.getType().cast<ShapedType>().getElementType();
+    ShapedType resultTy = cast<ShapedType>(op.getType());
+    Type resultETy = cast<ShapedType>(op.getType()).getElementType();

     Type accETy =
-        inElementTy.isa<IntegerType>() ? rewriter.getI32Type() : inElementTy;
+        isa<IntegerType>(inElementTy) ? rewriter.getI32Type() : inElementTy;
     ShapedType accTy = resultTy.clone(accETy);

     auto dynamicDimsOr =
@@ -872,7 +870,7 @@
           // a div however for quantized values input normalization had
           // to be applied.
           Value poolVal = args[0];
-          if (accETy.isa<FloatType>()) {
+          if (isa<FloatType>(accETy)) {
            auto countF = rewriter.create<arith::SIToFPOp>(loc, accETy, count);
            poolVal = rewriter.create<arith::DivFOp>(loc, poolVal, countF)
                          ->getResult(0);
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -134,8 +134,8 @@
   LogicalResult matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const final {
-    ShapedType operandTy = adaptor.getInput1().getType().cast<ShapedType>();
-    ShapedType resultTy = reshape.getType().template cast<ShapedType>();
+    ShapedType operandTy = cast<ShapedType>(adaptor.getInput1().getType());
+    ShapedType resultTy = cast<ShapedType>(reshape.getType());
     bool isDynamic = !operandTy.hasStaticShape();

     if (isDynamic && resultTy.getRank() != 1) {
@@ -172,8 +172,8 @@
   LogicalResult matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const final {
-    ShapedType operandTy = adaptor.getInput1().getType().cast<ShapedType>();
-    ShapedType resultTy = reshape.getType().template cast<ShapedType>();
+    ShapedType operandTy = cast<ShapedType>(adaptor.getInput1().getType());
+    ShapedType resultTy = cast<ShapedType>(reshape.getType());
     bool isDynamic = !operandTy.hasStaticShape();

     if (isDynamic && operandTy.getRank() != 1) {
@@ -211,8 +211,8 @@
   LogicalResult matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const final {
-    ShapedType operandTy = adaptor.getInput1().getType().cast<ShapedType>();
-    ShapedType resultTy = reshape.getType().template cast<ShapedType>();
+    ShapedType operandTy = cast<ShapedType>(adaptor.getInput1().getType());
+    ShapedType resultTy = cast<ShapedType>(reshape.getType());
     bool isDynamic = !operandTy.hasStaticShape();

     SmallVector<int64_t> intermediateShape;
@@ -247,7 +247,7 @@
     Value input = adaptor.getInput();
     SmallVector<int64_t> strides, sizes;
     ArrayRef<int64_t> starts = sliceOp.getStart();
-    strides.resize(sliceOp.getType().template cast<ShapedType>().getRank(), 1);
+    strides.resize(cast<ShapedType>(sliceOp.getType()).getRank(), 1);

     SmallVector<Value> dynSizes;
     for (const auto &i : llvm::enumerate(sliceOp.getSize())) {
@@ -284,7 +284,7 @@
     auto input = padOp.getInput1();
     auto padding = padOp.getPadding();

-    ShapedType inputTy = input.getType().cast<ShapedType>();
+    ShapedType inputTy = cast<ShapedType>(input.getType());
     Type elementTy = inputTy.getElementType();
     int64_t rank = inputTy.getRank();
@@ -297,11 +297,11 @@
           loc, padOp.getPadConst(), ValueRange({}));
     } else {
       TypedAttr constantAttr;
-      if (elementTy.isa<FloatType>()) {
+      if (isa<FloatType>(elementTy)) {
         constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
-      } else if (elementTy.isa<IntegerType>() && !padOp.getQuantizationInfo()) {
+      } else if (isa<IntegerType>(elementTy) && !padOp.getQuantizationInfo()) {
         constantAttr = rewriter.getIntegerAttr(elementTy, 0);
-      } else if (elementTy.isa<IntegerType>() && padOp.getQuantizationInfo()) {
+      } else if (isa<IntegerType>(elementTy) && padOp.getQuantizationInfo()) {
         int64_t value = padOp.getQuantizationInfo()->getInputZp();
         constantAttr = rewriter.getIntegerAttr(elementTy, value);
       }
@@ -355,8 +355,8 @@
   LogicalResult
   matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto inputType = op.getOperand(0).getType().template cast<ShapedType>();
-    auto resultType = op.getType().dyn_cast<RankedTensorType>();
+    auto inputType = cast<ShapedType>(op.getOperand(0).getType());
+    auto resultType = dyn_cast<RankedTensorType>(op.getType());

     Location loc = op.getLoc();
     int axis = op.getAxis();
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -123,7 +123,7 @@
 // constant stride.
 static std::optional<int64_t>
 getMemrefConstantHorizontalStride(ShapedType type) {
-  auto memrefType = type.dyn_cast<MemRefType>();
+  auto memrefType = dyn_cast<MemRefType>(type);
   if (!memrefType)
     return false;
   // If the memref is 0 or 1D the horizontal stride is 0.
@@ -193,10 +193,10 @@
 /// Return true if the constant is a splat to a 2D vector so that it can be
 /// converted to a MMA constant matrix op.
 static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
-  auto vecType = constantOp.getType().dyn_cast<VectorType>();
+  auto vecType = dyn_cast<VectorType>(constantOp.getType());
   if (!vecType || vecType.getRank() != 2)
     return false;
-  return constantOp.getValue().isa<SplatElementsAttr>();
+  return isa<SplatElementsAttr>(constantOp.getValue());
 }

 /// Return true if this is a broadcast from scalar to a 2D vector.
@@ -268,11 +268,11 @@
   // matrixB and matrixC operands. vector.extract_strided_slice op
   // is not supported on registers containing matrixA operands.
   if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
-    return (op->getResult(0).getType().cast<VectorType>() ==
-            (*contractOp).getRhs().getType().cast<VectorType>());
+    return (cast<VectorType>(op->getResult(0).getType()) ==
+            cast<VectorType>((*contractOp).getRhs().getType()));
   if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C)
-    return (op->getResult(0).getType().cast<VectorType>() ==
-            (*contractOp).getAcc().getType().cast<VectorType>());
+    return (cast<VectorType>(op->getResult(0).getType()) ==
+            cast<VectorType>((*contractOp).getAcc().getType()));

   return false;
 }
@@ -344,11 +344,11 @@
                                    bool useNvGpu) {
   auto hasVectorDest = [](Operation *op) {
     return llvm::any_of(op->getResultTypes(),
-                        [](Type t) { return t.isa<VectorType>(); });
+                        [](Type t) { return isa<VectorType>(t); });
   };
   auto hasVectorSrc = [](Operation *op) {
     return llvm::any_of(op->getOperandTypes(),
-                        [](Type t) { return t.isa<VectorType>(); });
+                        [](Type t) { return isa<VectorType>(t); });
   };
   SetVector<Operation *> opToConvert;
   op->walk([&](vector::ContractionOp contract) {
@@ -448,8 +448,8 @@
       (extOp = source.getDefiningOp())) {
     source = extOp->getOperand(0);
     resultType =
-        VectorType::get(resultType.cast<VectorType>().getShape(),
-                        source.getType().cast<VectorType>().getElementType());
+        VectorType::get(cast<VectorType>(resultType).getShape(),
+                        cast<VectorType>(source.getType()).getElementType());
   }

   auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
@@ -553,7 +553,7 @@
     bool isSignedExtend = isa<arith::ExtSIOp>(user);
     if (isSignedExtend || isa<arith::ExtUIOp>(user)) {
       elType = IntegerType::get(
-          op.getContext(), elType.cast<IntegerType>().getWidth(),
+          op.getContext(), cast<IntegerType>(elType).getWidth(),
IntegerType::Signed : IntegerType::Unsigned); mappingResult = user->getResult(0); fragType = inferFragType(user); @@ -610,7 +610,7 @@ SmallVector shape{regInfo.numRegistersPerFragment, regInfo.elementsPerRegister}; Type elType = regInfo.registerLLVMType; - if (auto vecType = elType.dyn_cast()) + if (auto vecType = dyn_cast(elType)) elType = vecType.getElementType(); return VectorType::get(shape, elType); } @@ -637,7 +637,7 @@ } VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); - auto dense = op.getValue().dyn_cast(); + auto dense = dyn_cast(op.getValue()); if (!dense) { LLVM_DEBUG(DBGS() << "not a splat\n"); return rewriter.notifyMatchFailure(op, "not a splat"); @@ -782,7 +782,7 @@ // If we are not transposing, then we can use vectorized loads. Otherwise, we // must load each element individually. if (!isTransposeLoad) { - if (!loadedElType.isa()) { + if (!isa(loadedElType)) { loadedElType = VectorType::get({1}, loadedElType); } @@ -805,7 +805,7 @@ rewriter.getI64ArrayAttr(i)); } } else { - if (auto vecType = loadedElType.dyn_cast()) { + if (auto vecType = dyn_cast(loadedElType)) { loadedElType = vecType.getElementType(); } for (int i = 0; i < vectorType.getShape()[0]; i++) { @@ -838,7 +838,7 @@ /// Return true if this is a shared memory memref type. static bool isSharedMemory(MemRefType type) { auto addressSpace = - type.getMemorySpace().dyn_cast_or_null(); + dyn_cast_or_null(type.getMemorySpace()); if (addressSpace && addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace()) return true; @@ -860,7 +860,7 @@ return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); bool isLdMatrixCompatible = - isSharedMemory(op.getSource().getType().cast()) && + isSharedMemory(cast(op.getSource().getType())) && nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128; VectorType vecTy = op.getVectorType(); @@ -929,7 +929,7 @@ static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl &results) { for (auto attr : arrayAttr) - results.push_back(attr.cast().getInt()); + results.push_back(cast(attr).getInt()); } static LogicalResult @@ -1041,9 +1041,9 @@ itC == valueMapping.end()) return rewriter.notifyMatchFailure(op, "no mapping"); Value opA = itA->second, opB = itB->second, opC = itC->second; - int64_t m = op.getLhs().getType().cast().getShape()[0]; - int64_t n = op.getRhs().getType().cast().getShape()[0]; - int64_t k = op.getLhs().getType().cast().getShape()[1]; + int64_t m = cast(op.getLhs().getType()).getShape()[0]; + int64_t n = cast(op.getRhs().getType()).getShape()[0]; + int64_t k = cast(op.getLhs().getType()).getShape()[1]; Value matmul = rewriter.create( op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k})); valueMapping[op.getResult()] = matmul; @@ -1060,11 +1060,11 @@ assert(constantSupportsMMAMatrixType(op)); auto splat = - op.getValue().cast().getSplatValue(); + cast(op.getValue()).getSplatValue(); auto scalarConstant = rewriter.create(op.getLoc(), splat.getType(), splat); const char *fragType = inferFragType(op); - auto vecType = op.getType().cast(); + auto vecType = cast(op.getType()); gpu::MMAMatrixType type = gpu::MMAMatrixType::get( vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType)); auto matrix = rewriter.create( diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -256,8 +256,8 @@ return failure(); // Resolve 
address. - auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType()) - .template cast(); + auto vtype = cast( + this->typeConverter->convertType(loadOrStoreOp.getVectorType())); Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(), adaptor.getIndices(), rewriter); Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype, @@ -277,7 +277,7 @@ LogicalResult matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - MemRefType memRefType = gather.getBaseType().dyn_cast(); + MemRefType memRefType = dyn_cast(gather.getBaseType()); assert(memRefType && "The base should be bufferized"); if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter()))) @@ -296,7 +296,7 @@ auto llvmNDVectorTy = adaptor.getIndexVec().getType(); // Handle the simple case of 1-D vector. - if (!llvmNDVectorTy.isa()) { + if (!isa(llvmNDVectorTy)) { auto vType = gather.getVectorType(); // Resolve address. Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), @@ -501,7 +501,7 @@ static Value createReductionNeutralValue(ReductionNeutralFPMin neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { - auto floatType = llvmType.cast(); + auto floatType = cast(llvmType); return rewriter.create( loc, llvmType, rewriter.getFloatAttr( @@ -513,7 +513,7 @@ static Value createReductionNeutralValue(ReductionNeutralFPMax neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { - auto floatType = llvmType.cast(); + auto floatType = cast(llvmType); return rewriter.create( loc, llvmType, rewriter.getFloatAttr( @@ -585,9 +585,9 @@ /// with vector types. static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs, Value rhs, bool isMin) { - auto floatType = getElementTypeOrSelf(lhs.getType()).cast(); + auto floatType = cast(getElementTypeOrSelf(lhs.getType())); Type i1Type = builder.getI1Type(); - if (auto vecType = lhs.getType().dyn_cast()) + if (auto vecType = dyn_cast(lhs.getType())) i1Type = VectorType::get(vecType.getShape(), i1Type); Value cmp = builder.create( loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt, @@ -768,7 +768,7 @@ return success(); } - if (!eltType.isa()) + if (!isa(eltType)) return failure(); // Floating-point reductions: add/mul/min/max @@ -966,14 +966,14 @@ // For all other cases, insert the individual values individually. int64_t v1Dim = v1Type.getDimSize(0); Type eltType; - if (auto arrayType = llvmType.dyn_cast()) + if (auto arrayType = dyn_cast(llvmType)) eltType = arrayType.getElementType(); else - eltType = llvmType.cast().getElementType(); + eltType = cast(llvmType).getElementType(); Value insert = rewriter.create(loc, llvmType); int64_t insPos = 0; for (const auto &en : llvm::enumerate(maskArrayAttr)) { - int64_t extPos = en.value().cast().getInt(); + int64_t extPos = cast(en.value()).getInt(); Value value = adaptor.getV1(); if (extPos >= v1Dim) { extPos -= v1Dim; @@ -1046,7 +1046,7 @@ } // One-shot extraction of vector from array (only requires extractvalue). 
- if (resultType.isa()) { + if (isa(resultType)) { SmallVector indices; for (auto idx : positionArrayAttr.getAsRange()) indices.push_back(idx.getInt()); @@ -1062,13 +1062,13 @@ if (positionAttrs.size() > 1) { SmallVector nMinusOnePosition; for (auto idx : positionAttrs.drop_back()) - nMinusOnePosition.push_back(idx.cast().getInt()); + nMinusOnePosition.push_back(cast(idx).getInt()); extracted = rewriter.create(loc, extracted, nMinusOnePosition); } // Remaining extraction of element from 1-D LLVM vector - auto position = positionAttrs.back().cast(); + auto position = cast(positionAttrs.back()); auto i64Type = IntegerType::get(rewriter.getContext(), 64); auto constant = rewriter.create(loc, i64Type, position); extracted = @@ -1169,7 +1169,7 @@ } // One-shot insertion of a vector into an array (only requires insertvalue). - if (sourceType.isa()) { + if (isa(sourceType)) { Value inserted = rewriter.create( loc, adaptor.getDest(), adaptor.getSource(), LLVM::convertArrayToIndices(positionArrayAttr)); @@ -1180,7 +1180,7 @@ // Potential extraction of 1-D vector from array. Value extracted = adaptor.getDest(); auto positionAttrs = positionArrayAttr.getValue(); - auto position = positionAttrs.back().cast(); + auto position = cast(positionAttrs.back()); auto oneDVectorType = destVectorType; if (positionAttrs.size() > 1) { oneDVectorType = reducedVectorTypeBack(destVectorType); @@ -1333,7 +1333,7 @@ ConversionPatternRewriter &rewriter) const override { auto loc = castOp->getLoc(); MemRefType sourceMemRefType = - castOp.getOperand().getType().cast(); + cast(castOp.getOperand().getType()); MemRefType targetMemRefType = castOp.getType(); // Only static shape casts supported atm. @@ -1342,13 +1342,13 @@ return failure(); auto llvmSourceDescriptorTy = - adaptor.getOperands()[0].getType().dyn_cast(); + dyn_cast(adaptor.getOperands()[0].getType()); if (!llvmSourceDescriptorTy) return failure(); MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]); - auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType) - .dyn_cast_or_null(); + auto llvmTargetDescriptorTy = dyn_cast_or_null( + typeConverter->convertType(targetMemRefType)); if (!llvmTargetDescriptorTy) return failure(); @@ -1418,7 +1418,7 @@ LogicalResult matchAndRewrite(vector::CreateMaskOp op, PatternRewriter &rewriter) const override { auto dstType = op.getType(); - if (dstType.getRank() != 1 || !dstType.cast().isScalable()) + if (dstType.getRank() != 1 || !cast(dstType).isScalable()) return failure(); IntegerType idxType = force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type(); @@ -1465,7 +1465,7 @@ // Make sure element type has runtime support. PrintConversion conversion = PrintConversion::None; - VectorType vectorType = printType.dyn_cast(); + VectorType vectorType = dyn_cast(printType); Type eltType = vectorType ? vectorType.getElementType() : printType; auto parent = printOp->getParentOfType(); Operation *printer; @@ -1481,7 +1481,7 @@ printer = LLVM::lookupOrCreatePrintBF16Fn(parent); } else if (eltType.isIndex()) { printer = LLVM::lookupOrCreatePrintU64Fn(parent); - } else if (auto intTy = eltType.dyn_cast()) { + } else if (auto intTy = dyn_cast(eltType)) { // Integers need a zero or sign extension on the operand // (depending on the source type) as well as a signed or // unsigned print method. Up to 64-bit is supported. 
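Note: every hunk in this patch applies one mechanical rewrite: member-style casts (t.isa<T>(), t.cast<T>(), t.dyn_cast<T>()) become the free functions isa<T>(t), cast<T>(t), dyn_cast<T>(t) that mlir/Support/LLVM.h re-exports from llvm/Support/Casting.h. A minimal sketch of the resulting idiom as used in the integer-print branch above; this is not from the patch, pickPrintWidth is a hypothetical helper name, and it assumes an MLIR development tree for the headers:

static unsigned pickPrintWidth(mlir::Type eltType) {
  // Free-function form: dyn_cast returns a null IntegerType on mismatch,
  // exactly like the old member form did.
  if (auto intTy = mlir::dyn_cast<mlir::IntegerType>(eltType))
    return intTy.getWidth() <= 32 ? 32 : 64; // up to 64-bit is supported
  return 64; // index and float elements use dedicated printers above
}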
@@ -1536,7 +1536,7 @@ void emitRanks(ConversionPatternRewriter &rewriter, Operation *op, Value value, Type type, Operation *printer, int64_t rank, PrintConversion conversion) const { - VectorType vectorType = type.dyn_cast(); + VectorType vectorType = dyn_cast(type); Location loc = op->getLoc(); if (!vectorType) { assert(rank == 0 && "The scalar case expects rank == 0"); @@ -1610,7 +1610,7 @@ LogicalResult matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - VectorType resultType = splatOp.getType().cast(); + VectorType resultType = cast(splatOp.getType()); if (resultType.getRank() > 1) return failure(); @@ -1633,7 +1633,7 @@ auto v = rewriter.create( splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero); - int64_t width = splatOp.getType().cast().getDimSize(0); + int64_t width = cast(splatOp.getType()).getDimSize(0); SmallVector zeroValues(width, 0); // Shuffle the value across the desired number of elements. diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -258,7 +258,7 @@ /// Return true if this transfer op operates on a source tensor. template static bool isTensorOp(OpTy xferOp) { - if (xferOp.getShapedType().template isa()) { + if (isa(xferOp.getShapedType())) { if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) { // TransferWriteOps on tensors have a result. assert(xferOp->getNumResults() > 0); @@ -314,7 +314,7 @@ /// /// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>> static MemRefType unpackOneDim(MemRefType type) { - auto vectorType = type.getElementType().dyn_cast(); + auto vectorType = dyn_cast(type.getElementType()); auto memrefShape = type.getShape(); SmallVector newMemrefShape; newMemrefShape.append(memrefShape.begin(), memrefShape.end()); @@ -408,8 +408,8 @@ getXferIndices(b, xferOp, iv, xferIndices); Location loc = xferOp.getLoc(); - auto bufferType = buffer.getType().dyn_cast(); - auto vecType = bufferType.getElementType().dyn_cast(); + auto bufferType = dyn_cast(buffer.getType()); + auto vecType = dyn_cast(bufferType.getElementType()); auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); auto newXferOp = b.create( loc, vecType, xferOp.getSource(), xferIndices, @@ -432,8 +432,8 @@ storeIndices.push_back(iv); Location loc = xferOp.getLoc(); - auto bufferType = buffer.getType().dyn_cast(); - auto vecType = bufferType.getElementType().dyn_cast(); + auto bufferType = dyn_cast(buffer.getType()); + auto vecType = dyn_cast(bufferType.getElementType()); auto vec = b.create(loc, vecType, xferOp.getPadding()); b.create(loc, vec, buffer, storeIndices); @@ -698,7 +698,7 @@ // Find and cast data buffer. How the buffer can be found depends on OpTy. 
ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter); auto dataBuffer = Strategy::getBuffer(xferOp); - auto dataBufferType = dataBuffer.getType().template dyn_cast(); + auto dataBufferType = dyn_cast(dataBuffer.getType()); auto castedDataType = unpackOneDim(dataBufferType); auto castedDataBuffer = locB.create(castedDataType, dataBuffer); @@ -707,8 +707,7 @@ Value castedMaskBuffer; if (xferOp.getMask()) { auto maskBuffer = getMaskBuffer(xferOp); - auto maskBufferType = - maskBuffer.getType().template dyn_cast(); + auto maskBufferType = dyn_cast(maskBuffer.getType()); if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) { // Do not unpack a dimension of the mask, if: // * To-be-unpacked transfer op dimension is a broadcast. @@ -889,7 +888,7 @@ SmallVector &indices) const { if (auto insertOp = getInsertOp(xferOp)) { for (Attribute attr : insertOp.getPosition()) - indices.push_back(attr.dyn_cast().getInt()); + indices.push_back(dyn_cast(attr).getInt()); } } @@ -908,7 +907,7 @@ auto insertOp = getInsertOp(xferOp); auto vec = getResultVector(xferOp, rewriter); - auto vecType = vec.getType().dyn_cast(); + auto vecType = dyn_cast(vec.getType()); auto xferVecType = xferOp.getVectorType(); auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(), xferVecType.getElementType()); @@ -1016,7 +1015,7 @@ SmallVector &indices) const { if (auto extractOp = getExtractOp(xferOp)) { for (Attribute attr : extractOp.getPosition()) - indices.push_back(attr.dyn_cast().getInt()); + indices.push_back(dyn_cast(attr).getInt()); } } @@ -1235,7 +1234,7 @@ if (xferOp.getTransferRank() == 0) return failure(); auto map = xferOp.getPermutationMap(); - auto memRefType = xferOp.getShapedType().template dyn_cast(); + auto memRefType = dyn_cast(xferOp.getShapedType()); if (!memRefType) return failure(); diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -43,7 +43,7 @@ // TODO: This does not take into account any memory layout or widening // constraints. E.g., a vector<3xi57> may report to occupy 3x57=171 bit, even // though in practice it will likely be stored as in a 4xi64 vector register. - if (auto vectorType = type.dyn_cast()) + if (auto vectorType = dyn_cast(type)) return vectorType.getNumElements() * vectorType.getElementTypeBitWidth(); return type.getIntOrFloatBitWidth(); } @@ -95,7 +95,7 @@ if (!resultType) return failure(); - if (resultType.isa()) { + if (isa(resultType)) { rewriter.replaceOp(castOp, adaptor.getSource()); return success(); } @@ -116,7 +116,7 @@ matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Only support extracting a scalar value now. - VectorType resultVectorType = extractOp.getType().dyn_cast(); + VectorType resultVectorType = dyn_cast(extractOp.getType()); if (resultVectorType && resultVectorType.getNumElements() > 1) return failure(); @@ -124,7 +124,7 @@ if (!dstType) return failure(); - if (adaptor.getVector().getType().isa()) { + if (isa(adaptor.getVector().getType())) { rewriter.replaceOp(extractOp, adaptor.getVector()); return success(); } @@ -156,7 +156,7 @@ Value srcVector = adaptor.getOperands().front(); // Extract vector<1xT> case. 
- if (dstType.isa()) { + if (isa(dstType)) { rewriter.replaceOpWithNewOp(extractOp, srcVector, offset); return success(); @@ -203,7 +203,7 @@ return success(); } - if (insertOp.getSourceType().isa() || + if (isa(insertOp.getSourceType()) || !spirv::CompositeType::isValid(insertOp.getDestVectorType())) return failure(); int32_t id = getFirstIntValue(insertOp.getPosition()); @@ -224,7 +224,7 @@ if (!resultType) return failure(); - if (adaptor.getVector().getType().isa()) { + if (isa(adaptor.getVector().getType())) { rewriter.replaceOp(extractOp, adaptor.getVector()); return success(); } @@ -252,7 +252,7 @@ if (!vectorType) return failure(); - if (vectorType.isa()) { + if (isa(vectorType)) { rewriter.replaceOp(insertOp, adaptor.getSource()); return success(); } @@ -285,18 +285,17 @@ return failure(); uint64_t offset = getFirstIntValue(insertOp.getOffsets()); - if (srcVector.getType().isa()) { - assert(!dstVector.getType().isa()); + if (isa(srcVector.getType())) { + assert(!isa(dstVector.getType())); rewriter.replaceOpWithNewOp( insertOp, dstVector.getType(), srcVector, dstVector, rewriter.getI32ArrayAttr(offset)); return success(); } - uint64_t totalSize = - dstVector.getType().cast().getNumElements(); + uint64_t totalSize = cast(dstVector.getType()).getNumElements(); uint64_t insertSize = - srcVector.getType().cast().getNumElements(); + cast(srcVector.getType()).getNumElements(); SmallVector indices(totalSize); std::iota(indices.begin(), indices.end(), 0); @@ -324,7 +323,7 @@ if (!resultType) return failure(); - auto srcVectorType = adaptor.getVector().getType().dyn_cast(); + auto srcVectorType = dyn_cast(adaptor.getVector().getType()); if (!srcVectorType || srcVectorType.getRank() != 1) return rewriter.notifyMatchFailure(reduceOp, "not 1-D vector source"); @@ -393,10 +392,10 @@ Type dstType = getTypeConverter()->convertType(op.getType()); if (!dstType) return failure(); - if (dstType.isa()) { + if (isa(dstType)) { rewriter.replaceOp(op, adaptor.getInput()); } else { - auto dstVecType = dstType.cast(); + auto dstVecType = cast(dstType); SmallVector source(dstVecType.getNumElements(), adaptor.getInput()); rewriter.replaceOpWithNewOp(op, dstType, @@ -422,7 +421,7 @@ if (oldSourceType.getNumElements() > 1) { SmallVector components = llvm::to_vector<4>( llvm::map_range(shuffleOp.getMask(), [](Attribute attr) -> int32_t { - return attr.cast().getValue().getZExtValue(); + return cast(attr).getValue().getZExtValue(); })); rewriter.replaceOpWithNewOp( shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(), diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp --- a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp +++ b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp @@ -65,7 +65,7 @@ newAttrs.push_back(attr); continue; } - auto segmentAttr = attr.getValue().cast(); + auto segmentAttr = cast(attr.getValue()); MLIRContext *context = segmentAttr.getContext(); DenseI32ArrayAttr newSegments; switch (action) { @@ -128,7 +128,7 @@ Value prevLoadForCompare = prevLoad; Value atomicResForCompare = atomicRes; - if (auto floatDataTy = dataType.dyn_cast()) { + if (auto floatDataTy = dyn_cast(dataType)) { Type equivInt = rewriter.getIntegerType(floatDataTy.getWidth()); prevLoadForCompare = rewriter.create(loc, equivInt, prevLoad); diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp --- a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp +++ 
b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp @@ -136,7 +136,7 @@ bool mlir::affine::isLoopMemoryParallel(AffineForOp forOp) { // Any memref-typed iteration arguments are treated as serializing. if (llvm::any_of(forOp.getResultTypes(), - [](Type type) { return type.isa<BaseMemRefType>(); })) + [](Type type) { return isa<BaseMemRefType>(type); })) return false; // Collect all load and store ops in loop nest rooted at 'forOp'. diff --git a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp --- a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp @@ -162,7 +162,7 @@ /// conservative. static bool isAccessIndexInvariant(Value iv, Value index) { assert(isAffineForInductionVar(iv) && "iv must be a AffineForOp"); - assert(index.getType().isa<IndexType>() && "index must be of IndexType"); + assert(isa<IndexType>(index.getType()) && "index must be of IndexType"); SmallVector<Operation *, 4> affineApplyOps; getReachableAffineApplyOps({index}, affineApplyOps); @@ -262,7 +262,7 @@ template <typename LoadOrStoreOp> static bool isVectorElement(LoadOrStoreOp memoryOp) { auto memRefType = memoryOp.getMemRefType(); - return memRefType.getElementType().template isa<VectorType>(); + return isa<VectorType>(memRefType.getElementType()); } using VectorizableOpFun = std::function<bool(AffineForOp, Operation &)>; diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp --- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp @@ -190,7 +190,7 @@ if (!hasEdge(srcId, dstId, value)) { outEdges[srcId].push_back({dstId, value}); inEdges[dstId].push_back({srcId, value}); - if (value.getType().isa<MemRefType>()) + if (isa<MemRefType>(value.getType())) memrefEdgeCount[value]++; } } @@ -200,7 +200,7 @@ Value value) { assert(inEdges.count(dstId) > 0); assert(outEdges.count(srcId) > 0); - if (value.getType().isa<MemRefType>()) { + if (isa<MemRefType>(value.getType())) { assert(memrefEdgeCount.count(value) > 0); memrefEdgeCount[value]--; } @@ -289,7 +289,7 @@ // By definition of edge, if the edge value is a non-memref value, // then the dependence is between a graph node which defines an SSA value // and another graph node which uses the SSA value. - if (!edge.value.getType().isa<MemRefType>()) + if (!isa<MemRefType>(edge.value.getType())) definingNodes.insert(edge.id); } @@ -473,7 +473,7 @@ ArrayRef<Edge> edges, const std::function<void(Edge)> &callback) { for (const auto &edge : edges) { // Skip if 'edge' is not a memref dependence edge. - if (!edge.value.getType().isa<MemRefType>()) + if (!isa<MemRefType>(edge.value.getType())) continue; assert(nodes.count(edge.id) > 0); // Skip if 'edge.id' is not a loop nest. @@ -808,13 +808,13 @@ } unsigned MemRefRegion::getRank() const { - return memref.getType().cast<MemRefType>().getRank(); + return cast<MemRefType>(memref.getType()).getRank(); } std::optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape( SmallVectorImpl<int64_t> *shape, std::vector<SmallVector<int64_t, 4>> *lbs, SmallVectorImpl<int64_t> *lbDivisors) const { - auto memRefType = memref.getType().cast<MemRefType>(); + auto memRefType = cast<MemRefType>(memref.getType()); unsigned rank = memRefType.getRank(); if (shape) shape->reserve(rank); @@ -875,7 +875,7 @@ void MemRefRegion::getLowerAndUpperBound(unsigned pos, AffineMap &lbMap, AffineMap &ubMap) const { assert(pos < cst.getNumDimVars() && "invalid position"); - auto memRefType = memref.getType().cast<MemRefType>(); + auto memRefType = cast<MemRefType>(memref.getType()); unsigned rank = memRefType.getRank(); assert(rank == cst.getNumDimVars() && "inconsistent memref region"); @@ -1049,7 +1049,7 @@ // to guard against potential over-approximation from projection. // TODO: Support dynamic memref dimensions. 
if (addMemRefDimBounds) { - auto memRefType = memref.getType().cast<MemRefType>(); + auto memRefType = cast<MemRefType>(memref.getType()); for (unsigned r = 0; r < rank; r++) { cst.addBound(BoundType::LB, /*pos=*/r, /*value=*/0); if (memRefType.isDynamicDim(r)) @@ -1071,7 +1071,7 @@ unsigned sizeInBits; if (elementType.isIntOrFloat()) { sizeInBits = elementType.getIntOrFloatBitWidth(); - } else if (auto vectorType = elementType.dyn_cast<VectorType>()) { + } else if (auto vectorType = dyn_cast<VectorType>(elementType)) { if (vectorType.getElementType().isIntOrFloat()) sizeInBits = vectorType.getElementTypeBitWidth() * vectorType.getNumElements(); @@ -1085,7 +1085,7 @@ // Returns the size of the region. std::optional<int64_t> MemRefRegion::getRegionSize() { - auto memRefType = memref.getType().cast<MemRefType>(); + auto memRefType = cast<MemRefType>(memref.getType()); if (!memRefType.getLayout().isIdentity()) { LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n"); @@ -1119,7 +1119,7 @@ if (!memRefType.hasStaticShape()) return std::nullopt; auto elementType = memRefType.getElementType(); - if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>()) + if (!elementType.isIntOrFloat() && !isa<VectorType>(elementType)) return std::nullopt; auto sizeInBytes = getMemRefIntOrFloatEltSizeInBytes(memRefType); @@ -1708,7 +1708,7 @@ } unsigned MemRefAccess::getRank() const { - return memref.getType().cast<MemRefType>().getRank(); + return cast<MemRefType>(memref.getType()).getRank(); } bool MemRefAccess::isStore() const { diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp --- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp @@ -289,7 +289,7 @@ // memref type. Call Op that returns one or more memref type results // is already taken care of, by the previous conditions. if (llvm::any_of(op.getOperandTypes(), - [&](Type t) { return t.isa<MemRefType>(); })) { + [&](Type t) { return isa<MemRefType>(t); })) { Node node(nextNodeId++, &op); nodes.insert({node.id, node}); } @@ -379,7 +379,7 @@ OpBuilder top(forInst->getParentRegion()); // Create new memref type based on slice bounds. auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef(); - auto oldMemRefType = oldMemRef.getType().cast<MemRefType>(); + auto oldMemRefType = cast<MemRefType>(oldMemRef.getType()); unsigned rank = oldMemRefType.getRank(); // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'. @@ -516,7 +516,7 @@ return WalkResult::advance(); for (Value v : op->getOperands()) // Collect memref values only. - if (v.getType().isa<MemRefType>()) + if (isa<MemRefType>(v.getType())) memRefValues.insert(v); return WalkResult::advance(); }); diff --git a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp --- a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp @@ -88,7 +88,7 @@ return MemRefType::Builder(oldMemRefType).setShape(newShape).setLayout({}); }; - auto oldMemRefType = oldMemRef.getType().cast<MemRefType>(); + auto oldMemRefType = cast<MemRefType>(oldMemRef.getType()); auto newMemRefType = doubleShape(oldMemRefType); // The double buffer is allocated right before 'forOp'. 
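Note: the memref hunks above all share one before/after shape. A sketch for orientation, not taken from the patch (memrefRank is a hypothetical helper; assumes the usual MLIR headers):

// Before the patch: memref.getType().cast<MemRefType>().getRank();
// After the patch (free function, same assert-on-mismatch semantics):
static unsigned memrefRank(mlir::Value memref) {
  return mlir::cast<mlir::MemRefType>(memref.getType()).getRank();
}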
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp --- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp @@ -100,9 +100,9 @@ SmallVector<Operation *> opsToSimplify; func.walk([&](Operation *op) { for (auto attr : op->getAttrs()) { - if (auto mapAttr = attr.getValue().dyn_cast<AffineMapAttr>()) + if (auto mapAttr = dyn_cast<AffineMapAttr>(attr.getValue())) simplifyAndUpdateAttribute(op, attr.getName(), mapAttr); - else if (auto setAttr = attr.getValue().dyn_cast<IntegerSetAttr>()) + else if (auto setAttr = dyn_cast<IntegerSetAttr>(attr.getValue())) simplifyAndUpdateAttribute(op, attr.getName(), setAttr); } diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp --- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp @@ -838,7 +838,7 @@ Value replacement) { assert(!valueVectorReplacement.contains(replaced) && "Vector replacement already registered"); - assert(replacement.getType().isa<VectorType>() && + assert(isa<VectorType>(replacement.getType()) && "Expected vector type in vector replacement"); valueVectorReplacement.map(replaced, replacement); } @@ -883,7 +883,7 @@ Value replacement) { assert(!valueScalarReplacement.contains(replaced) && "Scalar value replacement already registered"); - assert(!replacement.getType().isa<VectorType>() && + assert(!isa<VectorType>(replacement.getType()) && "Expected scalar type in scalar replacement"); valueScalarReplacement.map(replaced, replacement); } @@ -946,7 +946,7 @@ /// strategy on the scalar type. static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy) { - assert(!scalarTy.isa<VectorType>() && "Expected scalar type"); + assert(!isa<VectorType>(scalarTy) && "Expected scalar type"); return VectorType::get(strategy->vectorSizes, scalarTy); } @@ -1137,7 +1137,7 @@ // A vector operand that is not in the replacement map should never reach // this point. Reaching this point could mean that the code was already // vectorized and we shouldn't try to vectorize already vectorized code. - assert(!operand.getType().isa<VectorType>() && + assert(!isa<VectorType>(operand.getType()) && "Vector op not found in replacement map"); // Vectorize constant. 
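Note: the SimplifyAffineStructures hunk above keeps its else-if dispatch over attribute kinds; only the cast spelling changes. A sketch under the new convention (dispatchAttr is a hypothetical helper, not from the patch):

static void dispatchAttr(mlir::Attribute attr) {
  // dyn_cast<AffineMapAttr>(attr) yields a null attribute on mismatch,
  // so the chain's control flow is identical to the member-cast version.
  if (auto mapAttr = mlir::dyn_cast<mlir::AffineMapAttr>(attr))
    (void)mapAttr.getValue(); // AffineMap payload
  else if (auto setAttr = mlir::dyn_cast<mlir::IntegerSetAttr>(attr))
    (void)setAttr.getValue(); // IntegerSet payload
}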
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp --- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -1852,7 +1852,7 @@ int64_t numEltPerStride = 1; int64_t stride = 1; for (int d = bufferShape.size() - 1; d >= 1; d--) { - int64_t dimSize = region.memref.getType().cast().getDimSize(d); + int64_t dimSize = cast(region.memref.getType()).getDimSize(d); stride *= dimSize; numEltPerStride *= bufferShape[d]; // A stride is needed only if the region has a shorter extent than the @@ -1891,7 +1891,7 @@ return ubMap.getNumInputs() == ubOperands.size(); })); - unsigned rank = memref.getType().cast().getRank(); + unsigned rank = cast(memref.getType()).getRank(); assert(lbMaps.size() == rank && "wrong number of lb maps"); assert(ubMaps.size() == rank && "wrong number of ub maps"); @@ -2003,7 +2003,7 @@ auto loc = region.loc; auto memref = region.memref; - auto memRefType = memref.getType().cast(); + auto memRefType = cast(memref.getType()); if (!memRefType.getLayout().isIdentity()) { LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n"); @@ -2276,7 +2276,7 @@ assert(false && "expected load or store op"); return false; } - auto memRefType = region->memref.getType().cast(); + auto memRefType = cast(region->memref.getType()); if (!memRefType.hasStaticShape()) return false; diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -1119,9 +1119,9 @@ ArrayRef extraIndices, AffineMap indexRemap, ArrayRef extraOperands, ArrayRef symbolOperands, bool allowNonDereferencingOps) { - unsigned newMemRefRank = newMemRef.getType().cast().getRank(); + unsigned newMemRefRank = cast(newMemRef.getType()).getRank(); (void)newMemRefRank; // unused in opt mode - unsigned oldMemRefRank = oldMemRef.getType().cast().getRank(); + unsigned oldMemRefRank = cast(oldMemRef.getType()).getRank(); (void)oldMemRefRank; // unused in opt mode if (indexRemap) { assert(indexRemap.getNumSymbols() == symbolOperands.size() && @@ -1134,8 +1134,8 @@ } // Assert same elemental type. 
- assert(oldMemRef.getType().cast().getElementType() == - newMemRef.getType().cast().getElementType()); + assert(cast(oldMemRef.getType()).getElementType() == + cast(newMemRef.getType()).getElementType()); SmallVector usePositions; for (const auto &opEntry : llvm::enumerate(op->getOperands())) { @@ -1172,7 +1172,7 @@ // Perform index rewrites for the dereferencing op and then replace the op NamedAttribute oldMapAttrPair = affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef); - AffineMap oldMap = oldMapAttrPair.getValue().cast().getValue(); + AffineMap oldMap = cast(oldMapAttrPair.getValue()).getValue(); unsigned oldMapNumInputs = oldMap.getNumInputs(); SmallVector oldMapOperands( op->operand_begin() + memRefOperandPos + 1, @@ -1294,9 +1294,9 @@ ArrayRef symbolOperands, Operation *domOpFilter, Operation *postDomOpFilter, bool allowNonDereferencingOps, bool replaceInDeallocOp) { - unsigned newMemRefRank = newMemRef.getType().cast().getRank(); + unsigned newMemRefRank = cast(newMemRef.getType()).getRank(); (void)newMemRefRank; // unused in opt mode - unsigned oldMemRefRank = oldMemRef.getType().cast().getRank(); + unsigned oldMemRefRank = cast(oldMemRef.getType()).getRank(); (void)oldMemRefRank; if (indexRemap) { assert(indexRemap.getNumSymbols() == symbolOperands.size() && @@ -1309,8 +1309,8 @@ } // Assert same elemental type. - assert(oldMemRef.getType().cast().getElementType() == - newMemRef.getType().cast().getElementType()); + assert(cast(oldMemRef.getType()).getElementType() == + cast(newMemRef.getType()).getElementType()); std::unique_ptr domInfo; std::unique_ptr postDomInfo; @@ -1734,7 +1734,7 @@ SmallVector> tileSizePos; (void)getTileSizePos(layoutMap, tileSizePos); if (newMemRefType.getNumDynamicDims() > 0 && !tileSizePos.empty()) { - MemRefType oldMemRefType = oldMemRef.getType().cast(); + MemRefType oldMemRefType = cast(oldMemRef.getType()); SmallVector newDynamicSizes; createNewDynamicSizes(oldMemRefType, newMemRefType, layoutMap, allocOp, b, newDynamicSizes); diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp @@ -34,7 +34,7 @@ return constantOp->emitError("could not infer memory space"); // Only ranked tensors are supported. - if (!constantOp.getType().isa()) + if (!isa(constantOp.getType())) return failure(); // Only constants inside a module are supported. @@ -58,7 +58,7 @@ bool isWritable(Operation *op, Value value, const AnalysisState &state) const { // Memory locations returned by memref::GetGlobalOp may not be written to. - assert(value.isa()); + assert(isa(value)); return false; } }; @@ -84,21 +84,21 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto castOp = cast(op); - auto resultTensorType = castOp.getType().cast(); + auto resultTensorType = cast(castOp.getType()); FailureOr source = getBuffer(rewriter, castOp.getIn(), options); if (failed(source)) return failure(); - auto sourceType = source->getType().cast(); + auto sourceType = cast(source->getType()); // Result type should have same layout and address space as the source type. 
BaseMemRefType resultType; - if (auto rankedMemRefType = sourceType.dyn_cast()) { + if (auto rankedMemRefType = dyn_cast(sourceType)) { resultType = MemRefType::get( rankedMemRefType.getShape(), resultTensorType.getElementType(), rankedMemRefType.getLayout(), rankedMemRefType.getMemorySpace()); } else { - auto unrankedMemrefType = sourceType.cast(); + auto unrankedMemrefType = cast(sourceType); resultType = UnrankedMemRefType::get(resultTensorType.getElementType(), unrankedMemrefType.getMemorySpace()); } diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp --- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp @@ -63,10 +63,10 @@ Location loc, Type type, const APInt &value) { TypedAttr attr; - if (auto intTy = type.dyn_cast()) { + if (auto intTy = dyn_cast(type)) { attr = rewriter.getIntegerAttr(type, value); } else { - auto vecTy = type.cast(); + auto vecTy = cast(type); attr = SplatElementsAttr::get(vecTy, value); } @@ -78,10 +78,10 @@ Location loc, Type type, int64_t value) { unsigned elementBitWidth = 0; - if (auto intTy = type.dyn_cast()) + if (auto intTy = dyn_cast(type)) elementBitWidth = intTy.getWidth(); else - elementBitWidth = type.cast().getElementTypeBitWidth(); + elementBitWidth = cast(type).getElementTypeBitWidth(); return createScalarOrSplatConstant(rewriter, loc, type, APInt(elementBitWidth, value)); @@ -95,7 +95,7 @@ static Value extractLastDimSlice(ConversionPatternRewriter &rewriter, Location loc, Value input, int64_t lastOffset) { - ArrayRef shape = input.getType().cast().getShape(); + ArrayRef shape = cast(input.getType()).getShape(); assert(lastOffset < shape.back() && "Offset out of bounds"); // Scalarize the result in case of 1D vectors. @@ -125,7 +125,7 @@ // `input` is a scalar, this is a noop. static Value dropTrailingX1Dim(ConversionPatternRewriter &rewriter, Location loc, Value input) { - auto vecTy = input.getType().dyn_cast(); + auto vecTy = dyn_cast(input.getType()); if (!vecTy) return input; @@ -142,7 +142,7 @@ /// `input` is a scalar, this is a noop. static Value appendX1Dim(ConversionPatternRewriter &rewriter, Location loc, Value input) { - auto vecTy = input.getType().dyn_cast(); + auto vecTy = dyn_cast(input.getType()); if (!vecTy) return input; @@ -159,11 +159,11 @@ static Value insertLastDimSlice(ConversionPatternRewriter &rewriter, Location loc, Value source, Value dest, int64_t lastOffset) { - ArrayRef shape = dest.getType().cast().getShape(); + ArrayRef shape = cast(dest.getType()).getShape(); assert(lastOffset < shape.back() && "Offset out of bounds"); // Handle scalar source. 
- if (source.getType().isa()) + if (isa(source.getType())) return rewriter.create(loc, source, dest, lastOffset); SmallVector offsets(shape.size(), 0); @@ -215,14 +215,14 @@ unsigned newBitWidth = newType.getElementTypeBitWidth(); Attribute oldValue = op.getValueAttr(); - if (auto intAttr = oldValue.dyn_cast()) { + if (auto intAttr = dyn_cast(oldValue)) { auto [low, high] = getHalves(intAttr.getValue(), newBitWidth); auto newAttr = DenseElementsAttr::get(newType, {low, high}); rewriter.replaceOpWithNewOp(op, newAttr); return success(); } - if (auto splatAttr = oldValue.dyn_cast()) { + if (auto splatAttr = dyn_cast(oldValue)) { auto [low, high] = getHalves(splatAttr.getSplatValue(), newBitWidth); int64_t numSplatElems = splatAttr.getNumElements(); @@ -238,7 +238,7 @@ return success(); } - if (auto elemsAttr = oldValue.dyn_cast()) { + if (auto elemsAttr = dyn_cast(oldValue)) { int64_t numElems = elemsAttr.getNumElements(); SmallVector values; values.reserve(numElems * 2); @@ -527,9 +527,8 @@ Location loc = op->getLoc(); Type oldTy = op.getType(); - auto newTy = this->getTypeConverter() - ->convertType(oldTy) - .template dyn_cast_or_null(); + auto newTy = dyn_cast_or_null( + this->getTypeConverter()->convertType(oldTy)); if (!newTy) return rewriter.notifyMatchFailure( loc, llvm::formatv("unsupported type: {0}", op.getType())); @@ -549,11 +548,11 @@ /// Returns true iff the type is `index` or `vector<...index>`. static bool isIndexOrIndexVector(Type type) { - if (type.isa()) + if (isa(type)) return true; - if (auto vectorTy = type.dyn_cast()) - if (vectorTy.getElementType().isa()) + if (auto vectorTy = dyn_cast(type)) + if (isa(vectorTy.getElementType())) return true; return false; @@ -610,7 +609,7 @@ // Emit an index cast over the matching narrow type. Type narrowTy = rewriter.getIntegerType(typeConverter->getMaxTargetIntBitWidth()); - if (auto vecTy = resultType.dyn_cast()) + if (auto vecTy = dyn_cast(resultType)) narrowTy = VectorType::get(vecTy.getShape(), narrowTy); // Sign or zero-extend the result. Let the matching conversion pattern @@ -1116,7 +1115,7 @@ // Vector case. addConversion([this](VectorType ty) -> std::optional { - auto intTy = ty.getElementType().dyn_cast(); + auto intTy = dyn_cast(ty.getElementType()); if (!intTy) return ty; diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp --- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp @@ -86,12 +86,12 @@ continue; } - assert(value.getType().cast().isDynamicDim(*dim) && + assert(cast(value.getType()).isDynamicDim(*dim) && "expected dynamic dim"); - if (value.getType().isa()) { + if (isa(value.getType())) { // A tensor dimension is used: generate a tensor.dim. operands.push_back(b.create(loc, value, *dim)); - } else if (value.getType().isa()) { + } else if (isa(value.getType())) { // A memref dimension is used: generate a memref.dim. 
operands.push_back(b.create<memref::DimOp>(loc, value, *dim)); } else { diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp --- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp @@ -58,7 +58,7 @@ OpFoldResult ofr) { if (auto value = ofr.dyn_cast<Value>()) return value; - auto attr = ofr.dyn_cast<Attribute>().dyn_cast<IntegerAttr>(); + auto attr = dyn_cast<IntegerAttr>(ofr.dyn_cast<Attribute>()); assert(attr && "expect the op fold result casts to an integer attribute"); return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue()); } @@ -73,8 +73,8 @@ if (targetIsIndex ^ valueIsIndex) return b.create<arith::IndexCastOp>(loc, targetType, value); - auto targetIntegerType = targetType.dyn_cast<IntegerType>(); - auto valueIntegerType = value.getType().dyn_cast<IntegerType>(); + auto targetIntegerType = dyn_cast<IntegerType>(targetType); + auto valueIntegerType = dyn_cast<IntegerType>(value.getType()); assert(targetIntegerType && valueIntegerType && "unexpected cast between types other than integers and index"); assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness()); @@ -88,9 +88,9 @@ Type toType, bool isUnsignedCast) { if (operand.getType() == toType) return operand; - if (auto toIntType = toType.dyn_cast<IntegerType>()) { + if (auto toIntType = dyn_cast<IntegerType>(toType)) { // If operand is floating point, cast directly to the int type. - if (operand.getType().isa<FloatType>()) { + if (isa<FloatType>(operand.getType())) { if (isUnsignedCast) return b.create<arith::FPToUIOp>(loc, toType, operand); return b.create<arith::FPToSIOp>(loc, toType, operand); @@ -98,7 +98,7 @@ // Cast index operands directly to the int type. if (operand.getType().isIndex()) return b.create<arith::IndexCastOp>(loc, toType, operand); - if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) { + if (auto fromIntType = dyn_cast<IntegerType>(operand.getType())) { // Either extend or truncate. if (toIntType.getWidth() > fromIntType.getWidth()) { if (isUnsignedCast) @@ -108,15 +108,15 @@ if (toIntType.getWidth() < fromIntType.getWidth()) return b.create<arith::TruncIOp>(loc, toType, operand); } - } else if (auto toFloatType = toType.dyn_cast<FloatType>()) { + } else if (auto toFloatType = dyn_cast<FloatType>(toType)) { // If operand is integer, cast directly to the float type. // Note that it is unclear how to cast from BF16<->FP16. - if (operand.getType().isa<IntegerType>()) { + if (isa<IntegerType>(operand.getType())) { if (isUnsignedCast) return b.create<arith::UIToFPOp>(loc, toFloatType, operand); return b.create<arith::SIToFPOp>(loc, toFloatType, operand); } - if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) { + if (auto fromFloatType = dyn_cast<FloatType>(operand.getType())) { if (toFloatType.getWidth() > fromFloatType.getWidth()) return b.create<arith::ExtFOp>(loc, toFloatType, operand); if (toFloatType.getWidth() < fromFloatType.getWidth()) @@ -141,27 +141,27 @@ return b.create<arith::AndIOp>(loc, lhs, rhs); } Value ArithBuilder::add(Value lhs, Value rhs) { - if (lhs.getType().isa<FloatType>()) + if (isa<FloatType>(lhs.getType())) return b.create<arith::AddFOp>(loc, lhs, rhs); return b.create<arith::AddIOp>(loc, lhs, rhs); } Value ArithBuilder::sub(Value lhs, Value rhs) { - if (lhs.getType().isa<FloatType>()) + if (isa<FloatType>(lhs.getType())) return b.create<arith::SubFOp>(loc, lhs, rhs); return b.create<arith::SubIOp>(loc, lhs, rhs); } Value ArithBuilder::mul(Value lhs, Value rhs) { - if (lhs.getType().isa<FloatType>()) + if (isa<FloatType>(lhs.getType())) return b.create<arith::MulFOp>(loc, lhs, rhs); return b.create<arith::MulIOp>(loc, lhs, rhs); } Value ArithBuilder::sgt(Value lhs, Value rhs) { - if (lhs.getType().isa<FloatType>()) + if (isa<FloatType>(lhs.getType())) return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT, lhs, rhs); return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, lhs, rhs); } Value ArithBuilder::slt(Value lhs, Value rhs) { - if (lhs.getType().isa<FloatType>()) + if (isa<FloatType>(lhs.getType())) return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, lhs, rhs); return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, lhs, rhs); } diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp @@ -528,9 +528,9 @@ Operation *op = operand.getOwner(); Type type = operand.get().getType(); - bool isToken = type.isa<TokenType>(); - bool isGroup = type.isa<GroupType>(); - bool isValue = type.isa<ValueType>(); + bool isToken = isa<TokenType>(type); + bool isGroup = isa<GroupType>(type); + bool isValue = isa<ValueType>(type); // Drop reference after async token or group error check (coro await). if (auto await = dyn_cast(op)) diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp @@ -161,7 +161,7 @@ // We treat TokenType as state update marker to represent side-effects of // async computations - bool isStateful = func.getCallableResults().front().isa<TokenType>(); + bool isStateful = isa<TokenType>(func.getCallableResults().front()); std::optional<Value> retToken; if (isStateful) @@ -535,7 +535,7 @@ ConversionPatternRewriter &rewriter) const override { // We can only await on one of the `AwaitableType` (for `await` it can be // a `token` or a `value`, for `await_all` it must be a `group`). - if (!op.getOperand().getType().template isa<AwaitableType>()) + if (!isa<AwaitableType>(op.getOperand().getType())) return rewriter.notifyMatchFailure(op, "unsupported awaitable type"); // Check if await operation is inside the coroutine function. @@ -646,7 +646,7 @@ getReplacementValue(AwaitOp op, Value operand, ConversionPatternRewriter &rewriter) const override { // Load from the async value storage. 
- auto valueType = operand.getType().cast().getValueType(); + auto valueType = cast(operand.getType()).getValueType(); return rewriter.create(op->getLoc(), valueType, operand); } }; diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp --- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp +++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp @@ -59,7 +59,7 @@ // This transform op is currently restricted to ModuleOps and function ops. // Such ops are modified in-place. - transformResults.set(getTransformed().cast(), payloadOps); + transformResults.set(cast(getTransformed()), payloadOps); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp @@ -280,7 +280,7 @@ // defined in a non-dominated block or it is defined in the same block // but the current value is not dominated by the source value. if (!dominators.dominates(definingBlock, parentBlock) || - (definingBlock == parentBlock && value.isa())) { + (definingBlock == parentBlock && isa(value))) { toProcess.emplace_back(value, parentBlock); valuesToFree.insert(value); } else if (visitedValues.insert(std::make_tuple(value, definingBlock)) @@ -307,8 +307,8 @@ // Add new allocs and additional clone operations. for (Value value : valuesToFree) { - if (failed(value.isa() - ? introduceBlockArgCopy(value.cast()) + if (failed(isa(value) + ? introduceBlockArgCopy(cast(value)) : introduceValueCopyForRegionResult(value))) return failure(); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp @@ -43,7 +43,7 @@ /// exceed the stack space. static bool defaultIsSmallAlloc(Value alloc, unsigned maximumSizeInBytes, unsigned maxRankOfAllocatedMemRef) { - auto type = alloc.getType().dyn_cast(); + auto type = dyn_cast(alloc.getType()); if (!type || !alloc.getDefiningOp()) return false; if (!type.hasStaticShape()) { @@ -355,7 +355,7 @@ OpBuilder builder(startOperation); Operation *allocOp = alloc.getDefiningOp(); Operation *alloca = builder.create( - alloc.getLoc(), alloc.getType().cast(), + alloc.getLoc(), cast(alloc.getType()), allocOp->getOperands(), allocOp->getAttrs()); // Replace the original alloc by a newly created alloca. diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp @@ -53,7 +53,7 @@ SmallVector erasedResultTypes; BitVector erasedResultIndices(functionType.getNumResults()); for (const auto &resultType : llvm::enumerate(functionType.getResults())) { - if (auto memrefType = resultType.value().dyn_cast()) { + if (auto memrefType = dyn_cast(resultType.value())) { if (!hasStaticIdentityLayout(memrefType) && !hasFullyDynamicLayoutMap(memrefType)) { // Only buffers with static identity layout can be allocated. 
These can @@ -103,7 +103,7 @@ SmallVector copyIntoOutParams; SmallVector keepAsReturnOperands; for (Value operand : op.getOperands()) { - if (operand.getType().isa()) + if (isa(operand.getType())) copyIntoOutParams.push_back(operand); else keepAsReturnOperands.push_back(operand); @@ -137,7 +137,7 @@ SmallVector replaceWithNewCallResults; SmallVector replaceWithOutParams; for (OpResult result : op.getResults()) { - if (result.getType().isa()) + if (isa(result.getType())) replaceWithOutParams.push_back(result); else replaceWithNewCallResults.push_back(result); @@ -145,13 +145,13 @@ SmallVector outParams; OpBuilder builder(op); for (Value memref : replaceWithOutParams) { - if (!memref.getType().cast().hasStaticShape()) { + if (!cast(memref.getType()).hasStaticShape()) { op.emitError() << "cannot create out param for dynamically shaped result"; didFail = true; return; } - auto memrefType = memref.getType().cast(); + auto memrefType = cast(memref.getType()); auto allocType = MemRefType::get(memrefType.getShape(), memrefType.getElementType(), AffineMap(), memrefType.getMemorySpace()); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp @@ -68,7 +68,7 @@ [=](MemoryEffects::EffectInstance &it) { Value value = it.getValue(); return isa(it.getEffect()) && value && - value.isa() && + isa(value) && it.getResource() != SideEffects::AutomaticAllocationScopeResource::get(); }); @@ -149,7 +149,7 @@ FailureOr bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment, Attribute memorySpace) { - auto type = constantOp.getType().cast(); + auto type = cast(constantOp.getType()); auto moduleOp = constantOp->getParentOfType(); if (!moduleOp) return failure(); @@ -185,14 +185,14 @@ : IntegerAttr(); BufferizeTypeConverter typeConverter; - auto memrefType = typeConverter.convertType(type).cast(); + auto memrefType = cast(typeConverter.convertType(type)); if (memorySpace) memrefType = MemRefType::Builder(memrefType).setMemorySpace(memorySpace); auto global = globalBuilder.create( constantOp.getLoc(), (Twine("__constant_") + os.str()).str(), /*sym_visibility=*/globalBuilder.getStringAttr("private"), /*type=*/memrefType, - /*initial_value=*/constantOp.getValue().cast(), + /*initial_value=*/cast(constantOp.getValue()), /*constant=*/true, /*alignment=*/memrefAlignment); symbolTable.insert(global); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -44,7 +44,7 @@ static Value materializeToTensor(OpBuilder &builder, TensorType type, ValueRange inputs, Location loc) { assert(inputs.size() == 1); - assert(inputs[0].getType().isa()); + assert(isa(inputs[0].getType())); return builder.create(loc, type, inputs[0]); } @@ -66,11 +66,11 @@ ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1 && "expected exactly one input"); - if (auto inputType = inputs[0].getType().dyn_cast()) { + if (auto inputType = dyn_cast(inputs[0].getType())) { // MemRef to MemRef cast. assert(inputType != type && "expected different types"); // Unranked to ranked and ranked to unranked casts must be explicit. 
- auto rankedDestType = type.dyn_cast(); + auto rankedDestType = dyn_cast(type); if (!rankedDestType) return nullptr; FailureOr replacement = @@ -80,7 +80,7 @@ return *replacement; } - if (inputs[0].getType().isa()) { + if (isa(inputs[0].getType())) { // Tensor to MemRef cast. return builder.create(loc, type, inputs[0]); } @@ -222,7 +222,7 @@ parseLayoutMapOption(unknownTypeConversion); opt.unknownTypeConverterFn = [=](Value value, Attribute memorySpace, const BufferizationOptions &options) { - auto tensorType = value.getType().cast(); + auto tensorType = cast(value.getType()); if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap) return bufferization::getMemRefTypeWithStaticIdentityLayout( tensorType, memorySpace); @@ -325,7 +325,7 @@ // BufferizableOpInterface-based Bufferization //===----------------------------------------------------------------------===// -static bool isaTensor(Type t) { return t.isa(); } +static bool isaTensor(Type t) { return isa(t); } /// Return true if the given op has a tensor result or a tensor operand. static bool hasTensorSemantics(Operation *op) { @@ -549,7 +549,7 @@ options.unknownTypeConverterFn = [](Value value, Attribute memorySpace, const BufferizationOptions &options) { return getMemRefTypeWithStaticIdentityLayout( - value.getType().cast(), memorySpace); + cast(value.getType()), memorySpace); }; options.opFilter.allowDialect(); return options; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp @@ -33,12 +33,12 @@ Operation *insertionPoint, const SmallVector &neededValues) { for (Value val : neededValues) { - if (auto bbArg = val.dyn_cast()) { + if (auto bbArg = dyn_cast(val)) { Block *owner = bbArg.getOwner(); if (!owner->findAncestorOpInBlock(*insertionPoint)) return false; } else { - auto opResult = val.cast(); + auto opResult = cast(val); if (!domInfo.dominates(opResult.getOwner(), insertionPoint)) return false; } @@ -75,7 +75,7 @@ // * in case of an OpResult: There must be at least one op right after the // defining op (the anchor op or one of its // parents). 
- if (auto bbArg = val.dyn_cast()) { + if (auto bbArg = dyn_cast(val)) { insertionPointCandidates.push_back( &bbArg.getOwner()->getOperations().front()); } else { diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -60,7 +60,7 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index, const BufferizationOptions &options) { auto tensorType = - funcOp.getFunctionType().getInput(index).dyn_cast(); + dyn_cast(funcOp.getFunctionType().getInput(index)); assert(tensorType && "expected TensorType"); BaseMemRefType memrefType = options.functionArgTypeConverterFn( @@ -71,7 +71,7 @@ if (!layoutAttr) return memrefType; - auto rankedMemrefType = memrefType.dyn_cast(); + auto rankedMemrefType = dyn_cast(memrefType); assert(rankedMemrefType && "buffer layout not supported on unranked tensors"); return MemRefType::get( rankedMemrefType.getShape(), rankedMemrefType.getElementType(), @@ -224,7 +224,7 @@ for (const auto &it : llvm::enumerate(callOp.getResultTypes())) { unsigned returnValIdx = it.index(); Type returnType = it.value(); - if (!returnType.isa()) { + if (!isa(returnType)) { // Non-tensor values are returned. retValMapping[returnValIdx] = resultTypes.size(); resultTypes.push_back(returnType); @@ -242,7 +242,7 @@ Value tensorOperand = opOperand.get(); // Non-tensor operands are just copied. - if (!tensorOperand.getType().isa()) { + if (!isa(tensorOperand.getType())) { newOperands[idx] = tensorOperand; continue; } @@ -342,7 +342,7 @@ SmallVector argTypes; for (const auto &it : llvm::enumerate(funcType.getInputs())) { Type argType = it.value(); - if (auto tensorType = argType.dyn_cast()) { + if (auto tensorType = dyn_cast(argType)) { argTypes.push_back( getBufferizedFunctionArgType(funcOp, it.index(), options)); continue; @@ -356,7 +356,7 @@ if (funcOp.getBody().empty()) { SmallVector retTypes; for (Type resultType : funcType.getResults()) { - if (resultType.isa()) + if (isa(resultType)) return funcOp->emitError() << "cannot bufferize bodiless function " << "that returns a tensor"; retTypes.push_back(resultType); @@ -373,7 +373,7 @@ // 1. Rewrite the bbArgs. Turn every tensor bbArg into a memref bbArg. Block &frontBlock = funcOp.getBody().front(); for (BlockArgument &bbArg : frontBlock.getArguments()) { - auto tensorType = bbArg.getType().dyn_cast(); + auto tensorType = dyn_cast(bbArg.getType()); // Non-tensor types stay the same. if (!tensorType) continue; @@ -404,7 +404,7 @@ SmallVector returnValues; for (OpOperand &returnOperand : returnOp->getOpOperands()) { Value returnVal = returnOperand.get(); - auto tensorType = returnVal.getType().dyn_cast(); + auto tensorType = dyn_cast(returnVal.getType()); rewriter.setInsertionPoint(returnOp); // If not a tensor type just forward it. @@ -436,7 +436,7 @@ bool isWritable(Operation *op, Value value, const AnalysisState &state) const { auto funcOp = cast(op); - BlockArgument bbArg = value.dyn_cast(); + BlockArgument bbArg = dyn_cast(value); assert(bbArg && "expected BlockArgument"); // "bufferization.writable" overrides other writability decisions. 
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -66,7 +66,7 @@ using namespace mlir; using namespace mlir::bufferization; -static bool isaTensor(Type t) { return t.isa(); } +static bool isaTensor(Type t) { return isa(t); } //===----------------------------------------------------------------------===// // Bufferization-specific attribute manipulation. //===----------------------------------------------------------------------===// @@ -85,11 +85,11 @@ SmallVector inPlaceVector; if (auto attr = op->getAttr(kInPlaceOperandsAttrName)) { inPlaceVector = SmallVector(llvm::to_vector<4>( - attr.cast().getAsValueRange())); + cast(attr).getAsValueRange())); } else { inPlaceVector = SmallVector(op->getNumOperands(), "none"); for (OpOperand &opOperand : op->getOpOperands()) - if (opOperand.get().getType().isa()) + if (isa(opOperand.get().getType())) inPlaceVector[opOperand.getOperandNumber()] = "false"; } inPlaceVector[opOperand.getOperandNumber()] = inPlace ? "true" : "false"; @@ -107,12 +107,12 @@ // Set up alias sets. op->walk([&](Operation *op) { for (Value v : op->getResults()) - if (v.getType().isa()) + if (isa(v.getType())) createAliasInfoEntry(v); for (Region &r : op->getRegions()) for (Block &b : r.getBlocks()) for (auto bbArg : b.getArguments()) - if (bbArg.getType().isa()) + if (isa(bbArg.getType())) createAliasInfoEntry(bbArg); }); @@ -121,7 +121,7 @@ if (!options.isOpAllowed(bufferizableOp)) return WalkResult::skip(); for (OpOperand &opOperand : bufferizableOp->getOpOperands()) - if (opOperand.get().getType().isa()) + if (isa(opOperand.get().getType())) if (bufferizableOp.mustBufferizeInPlace(opOperand, *this)) bufferizeInPlace(opOperand); return WalkResult::advance(); @@ -187,13 +187,13 @@ for (OpOperand &returnValOperand : returnOp->getOpOperands()) { Value returnVal = returnValOperand.get(); // Skip non-tensor values. - if (!returnVal.getType().isa()) + if (!isa(returnVal.getType())) continue; // Add all aliases of the returned value. But only the ones that are in // the same block. applyOnAliases(returnVal, [&](Value v) { - if (auto bbArg = v.dyn_cast()) { + if (auto bbArg = dyn_cast(v)) { if (bbArg.getOwner()->getParentOp() == returnOp->getParentOp()) yieldedTensors.insert(bbArg); return; @@ -217,7 +217,7 @@ // Check all tensor OpResults. for (OpResult opResult : op->getOpResults()) { - if (!opResult.getType().isa()) + if (!isa(opResult.getType())) continue; // If there is no preceding definition, the tensor contents are @@ -259,7 +259,7 @@ return bufferizableOp.isWritable(value, *this); // Query BufferizableOpInterface to see if the BlockArgument is writable.
- if (auto bbArg = value.dyn_cast()) + if (auto bbArg = dyn_cast(value)) if (auto bufferizableOp = getOptions().dynCastBufferizableOp(bbArg.getOwner()->getParentOp())) return bufferizableOp.isWritable(bbArg, *this); @@ -431,12 +431,12 @@ id + "[READ: " + std::to_string(uRead->getOperandNumber()) + "]"; readingOp->setAttr(readAttr, b.getUnitAttr()); - if (auto opResult = definition.dyn_cast()) { + if (auto opResult = dyn_cast(definition)) { std::string defAttr = id + "[DEF: result " + std::to_string(opResult.getResultNumber()) + "]"; opResult.getDefiningOp()->setAttr(defAttr, b.getUnitAttr()); } else { - auto bbArg = definition.cast(); + auto bbArg = cast(definition); std::string defAttr = id + "[DEF: bbArg " + std::to_string(bbArg.getArgNumber()) + "]"; bbArg.getOwner()->getParentOp()->setAttr(defAttr, b.getUnitAttr()); @@ -581,7 +581,7 @@ continue; } } else { - auto bbArg = definition.cast(); + auto bbArg = cast(definition); Block *block = bbArg.getOwner(); if (!block->findAncestorOpInBlock(*conflictingWritingOp)) { LLVM_DEBUG(llvm::dbgs() << " no conflict: definition is bbArg " @@ -715,12 +715,12 @@ static int64_t counter = 0; OpBuilder b(value.getContext()); std::string id = "W_" + std::to_string(counter++); - if (auto opResult = value.dyn_cast()) { + if (auto opResult = dyn_cast(value)) { std::string attr = id + "[NOT-WRITABLE: result " + std::to_string(opResult.getResultNumber()) + "]"; opResult.getDefiningOp()->setAttr(attr, b.getUnitAttr()); } else { - auto bbArg = value.cast(); + auto bbArg = cast(value); std::string attr = id + "[NOT-WRITABLE: bbArg " + std::to_string(bbArg.getArgNumber()) + "]"; bbArg.getOwner()->getParentOp()->setAttr(attr, b.getUnitAttr()); @@ -812,7 +812,7 @@ OneShotAnalysisState::analyzeSingleOp(Operation *op, const DominanceInfo &domInfo) { for (OpOperand &opOperand : op->getOpOperands()) - if (opOperand.get().getType().isa()) + if (isa(opOperand.get().getType())) if (failed(bufferizableInPlaceAnalysisImpl(opOperand, *this, domInfo))) return failure(); return success(); @@ -831,7 +831,7 @@ for (Operation *op : ops) { if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) { for (OpResult opResult : op->getOpResults()) { - if (!opResult.getType().isa()) + if (!isa(opResult.getType())) continue; AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult); if (aliases.getNumAliases() == 0) @@ -958,7 +958,7 @@ } for (OpOperand &opOperand : op->getOpOperands()) { - if (opOperand.get().getType().isa()) { + if (isa(opOperand.get().getType())) { if (wouldCreateReadAfterWriteInterference( opOperand, domInfo, state, /*checkConsistencyOnly=*/true)) { @@ -984,7 +984,7 @@ // Add __inplace_operands_attr__. op->walk([&](Operation *op) { for (OpOperand &opOperand : op->getOpOperands()) - if (opOperand.get().getType().isa()) + if (isa(opOperand.get().getType())) setInPlaceOpOperand(opOperand, state.isInPlace(opOperand)); }); } @@ -1031,12 +1031,12 @@ for (OpOperand &returnValOperand : returnOp->getOpOperands()) { Value returnVal = returnValOperand.get(); // Skip non-tensor values. 
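Several of the OneShotAnalysis hunks above dispatch a Value into its two concrete kinds. Because a Value is either an OpResult or a BlockArgument, a failed dyn_cast to one kind licenses an unchecked cast to the other. A hypothetical helper mirroring that dispatch with the free-function casts:

#include "mlir/IR/Block.h"
#include "mlir/IR/Value.h"
using namespace mlir;

// Find the operation "responsible" for a value: its defining op for results,
// the parent op of the owning block for block arguments.
static Operation *getDefiningOrParentOp(Value value) {
  if (auto opResult = dyn_cast<OpResult>(value))
    return opResult.getOwner();
  // Every Value that is not an OpResult is a BlockArgument.
  return cast<BlockArgument>(value).getOwner()->getParentOp();
}
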
- if (!returnVal.getType().isa()) + if (!isa(returnVal.getType())) continue; bool foundEquivValue = false; state.applyOnEquivalenceClass(returnVal, [&](Value equivVal) { - if (auto bbArg = equivVal.dyn_cast()) { + if (auto bbArg = dyn_cast(equivVal)) { Operation *definingOp = bbArg.getOwner()->getParentOp(); if (definingOp->isProperAncestor(returnOp)) foundEquivValue = true; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -109,9 +109,9 @@ SmallVector equivBbArgs; if (op->hasAttr(kEquivalentArgsAttr)) { - auto attr = op->getAttr(kEquivalentArgsAttr).cast(); + auto attr = cast(op->getAttr(kEquivalentArgsAttr)); equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) { - return a.cast().getValue().getSExtValue(); + return cast(a).getValue().getSExtValue(); })); } else { equivBbArgs.append(op->getNumOperands(), -1); @@ -132,10 +132,10 @@ // return value may alias with any tensor bbArg. FunctionType type = funcOp.getFunctionType(); for (const auto &inputIt : llvm::enumerate(type.getInputs())) { - if (!inputIt.value().isa()) + if (!isa(inputIt.value())) continue; for (const auto &resultIt : llvm::enumerate(type.getResults())) { - if (!resultIt.value().isa()) + if (!isa(resultIt.value())) continue; int64_t returnIdx = resultIt.index(); int64_t bbArgIdx = inputIt.index(); @@ -150,9 +150,9 @@ assert(returnOp && "expected func with single return op"); for (OpOperand &returnVal : returnOp->getOpOperands()) - if (returnVal.get().getType().isa()) + if (isa(returnVal.get().getType())) for (BlockArgument bbArg : funcOp.getArguments()) - if (bbArg.getType().isa()) { + if (isa(bbArg.getType())) { int64_t returnIdx = returnVal.getOperandNumber(); int64_t bbArgIdx = bbArg.getArgNumber(); if (state.areEquivalentBufferizedValues(returnVal.get(), bbArg)) { @@ -193,7 +193,7 @@ for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e; ++idx) { // Skip non-tensor arguments. - if (!funcOp.getFunctionType().getInput(idx).isa()) + if (!isa(funcOp.getFunctionType().getInput(idx))) continue; bool isRead; bool isWritten; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp @@ -58,7 +58,7 @@ for (OpOperand &opOperand : bufferizableOp->getOpOperands()) { Value operand = opOperand.get(); // Skip non-tensor operands. - if (!operand.getType().isa()) + if (!isa(operand.getType())) continue; // Skip operands that do not bufferize to memory writes. if (!bufferizableOp.bufferizesToMemoryWrite(opOperand, state)) @@ -85,7 +85,7 @@ // Insert a tensor copy and replace all uses inside of repetitive regions. 
rewriter.setInsertionPoint(bufferizableOp); auto tensorCopy = rewriter.create( - bufferizableOp->getLoc(), operand.getType().cast(), + bufferizableOp->getLoc(), cast(operand.getType()), /*dynamicSizes=*/ValueRange(), /*copy=*/operand, /*memory_space=*/IntegerAttr()); for (OpOperand *use : usesInsideRegion) @@ -137,7 +137,7 @@ SmallVector escapeAttrValue; bool foundTensorResult = false; for (OpResult opResult : op->getOpResults()) { - if (!opResult.getType().isa() || + if (!isa(opResult.getType()) || !bufferizableOp.bufferizesToAllocation(opResult)) { escapeAttrValue.push_back(false); continue; diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp --- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp @@ -257,19 +257,19 @@ bool hasBlockMapping = llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) { - return attr.isa(); + return isa(attr); }); bool hasThreadMapping = llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) { - return attr.isa(); + return isa(attr); }); bool hasWarpMapping = llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) { - return attr.isa(); + return isa(attr); }); bool hasLinearMapping = llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) { - return attr.isa(); + return isa(attr); }); int64_t countMappingTypes = 0; countMappingTypes += hasBlockMapping ? 1 : 0; @@ -520,7 +520,7 @@ ArrayRef{forallMappingAttrs}.take_front( forallOp.getInductionVars().size()))) { Value peIdOp = mappingIdOps[static_cast( - dim.cast().getMappingId())]; + cast(dim).getMappingId())]; bvm.map(iv, peIdOp); } diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp --- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp @@ -214,7 +214,7 @@ /// Returns an accumulator factory that creates an op specified by opName. AccumulatorFactory getFactory(gpu::AllReduceOperation opName) { - bool isFloatingPoint = valueType.isa(); + bool isFloatingPoint = isa(valueType); switch (opName) { case gpu::AllReduceOperation::ADD: return isFloatingPoint ? getFactory() diff --git a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp --- a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp @@ -158,9 +158,9 @@ transform(executeOp.getResultTypes(), std::back_inserter(resultTypes), [](Type type) { // Extract value type from !async.value. 
- if (auto valueType = type.dyn_cast()) + if (auto valueType = dyn_cast(type)) return valueType.getValueType(); - assert(type.isa() && "expected token type"); + assert(isa(type) && "expected token type"); return type; }); transform(results, std::back_inserter(resultTypes), @@ -305,9 +305,9 @@ executeOp.getBodyResults(), [](OpResult result) { if (result.use_empty() || result.hasOneUse()) return false; - auto valueType = result.getType().dyn_cast(); + auto valueType = dyn_cast(result.getType()); return valueType && - valueType.getValueType().isa(); + isa(valueType.getValueType()); }); if (multiUseResults.empty()) return; diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -338,7 +338,7 @@ if (!resultAttr) return failure(); - dataLayoutSpec = resultAttr.dyn_cast(); + dataLayoutSpec = dyn_cast(resultAttr); if (!dataLayoutSpec) return failure(); } @@ -410,7 +410,7 @@ SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) { for (SymbolTable::SymbolUse symbolUse : *symbolUses) { StringRef symbolName = - symbolUse.getSymbolRef().cast().getValue(); + cast(symbolUse.getSymbolRef()).getValue(); if (symbolTable.lookup(symbolName)) continue; diff --git a/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp b/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp --- a/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp @@ -30,7 +30,7 @@ /// single-iteration loops. Maps the innermost loops to thread dimensions, in /// reverse order to enable access coalescing in the innermost loop. static void insertCopyLoops(ImplicitLocOpBuilder &b, Value from, Value to) { - auto memRefType = from.getType().cast(); + auto memRefType = cast(from.getType()); auto rank = memRefType.getRank(); SmallVector lbs, ubs, steps; @@ -121,8 +121,8 @@ /// pointed to by "from". In case a smaller block would be sufficient, the /// caller can create a subview of the memref and promote it instead. static void insertCopies(Region ®ion, Location loc, Value from, Value to) { - auto fromType = from.getType().cast(); - auto toType = to.getType().cast(); + auto fromType = cast(from.getType()); + auto toType = cast(to.getType()); (void)fromType; (void)toType; assert(fromType.getShape() == toType.getShape()); @@ -143,7 +143,7 @@ /// copies will be inserted in the beginning and in the end of the function. void mlir::promoteToWorkgroupMemory(GPUFuncOp op, unsigned arg) { Value value = op.getArgument(arg); - auto type = value.getType().dyn_cast(); + auto type = dyn_cast(value.getType()); assert(type && type.hasStaticShape() && "can only promote memrefs"); // Get the type of the buffer in the workgroup memory. diff --git a/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp b/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp --- a/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp +++ b/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp @@ -67,7 +67,7 @@ ConstraintVerifier &context) const { // Check that the base is the expected one. - auto dynAttr = attr.dyn_cast(); + auto dynAttr = dyn_cast(attr); if (!dynAttr || dynAttr.getAttrDef() != attrDef) { if (emitError) { StringRef dialectName = attrDef->getDialect()->getNamespace(); @@ -102,7 +102,7 @@ function_ref emitError, Attribute attr, ConstraintVerifier &context) const { // Check that the base is a TypeAttr. 
- auto typeAttr = attr.dyn_cast(); + auto typeAttr = dyn_cast(attr); if (!typeAttr) { if (emitError) return emitError() << "expected type, got attribute '" << attr; @@ -110,7 +110,7 @@ } // Check that the type base is the expected one. - auto dynType = typeAttr.getValue().dyn_cast(); + auto dynType = dyn_cast(typeAttr.getValue()); if (!dynType || dynType.getTypeDef() != typeDef) { if (emitError) { StringRef dialectName = typeDef->getDialect()->getNamespace(); diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp --- a/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp +++ b/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp @@ -25,11 +25,11 @@ /// Attempt to extract a filename for the given loc. static FileLineColLoc extractFileLoc(Location loc) { - if (auto fileLoc = loc.dyn_cast()) + if (auto fileLoc = dyn_cast(loc)) return fileLoc; - if (auto nameLoc = loc.dyn_cast()) + if (auto nameLoc = dyn_cast(loc)) return extractFileLoc(nameLoc.getChildLoc()); - if (auto opaqueLoc = loc.dyn_cast()) + if (auto opaqueLoc = dyn_cast(loc)) return extractFileLoc(opaqueLoc.getFallbackLocation()); return FileLineColLoc(); } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp @@ -607,7 +607,7 @@ return diag; Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position)); - if (getResult().getType().isa()) { + if (isa(getResult().getType())) { results.setValues(cast(getResult()), result); return DiagnosedSilenceableFailure::success(); } @@ -648,7 +648,7 @@ LogicalResult transform::MatchStructuredResultOp::verify() { if ((getAny() || getSingle()) ^ - getResult().getType().isa()) { + isa(getResult().getType())) { return emitOpError() << "expects either the any/single keyword or the type " "value handle result type"; } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -87,7 +87,7 @@ SmallVector &result, ArrayRef ofrs) { for (OpFoldResult ofr : ofrs) { if (ofr.is()) { - if (!ofr.get().isa()) + if (!isa(ofr.get())) return transformOp.emitDefiniteFailure() << "expected IntegerAttr"; result.push_back(ofr); continue; @@ -155,7 +155,7 @@ llvm::map_range(state.getPayloadValues(getTarget()), [&](Value v) { return linalg::bufferizeToAllocation(rewriter, v, memorySpace); })); - results.setValues(getTransformed().cast(), transformed); + results.setValues(cast(getTransformed()), transformed); return DiagnosedSilenceableFailure::success(); } @@ -276,7 +276,7 @@ if (!sizesAttr) return parser.emitError(opLoc) << "expected '" << sizesAttrName << "' attribute"; - auto sizesArrayAttr = sizesAttr.dyn_cast(); + auto sizesArrayAttr = dyn_cast(sizesAttr); if (!sizesArrayAttr) return parser.emitError(opLoc) << "'" << sizesAttrName << "' attribute must be an array"; @@ -389,7 +389,7 @@ // Tile the producer. int64_t resultNumber = - sliceOpToTile.getSource().cast().getResultNumber(); + cast(sliceOpToTile.getSource()).getResultNumber(); LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n"); FailureOr tileAndFuseResult = @@ -411,10 +411,7 @@ // Replace the extract op. 
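The DIScopeForLLVMFuncOp hunk above converts a recursive Location unwrapper. Since the template arguments are elided in the patch text, here is a plausible reconstruction of the resulting function, with FileLineColLoc, NameLoc, and OpaqueLoc inferred from the surrounding identifiers rather than copied from the original:

#include "mlir/IR/Location.h"
using namespace mlir;

static FileLineColLoc extractFileLoc(Location loc) {
  if (auto fileLoc = dyn_cast<FileLineColLoc>(loc))
    return fileLoc;
  if (auto nameLoc = dyn_cast<NameLoc>(loc))
    return extractFileLoc(nameLoc.getChildLoc());
  if (auto opaqueLoc = dyn_cast<OpaqueLoc>(loc))
    return extractFileLoc(opaqueLoc.getFallbackLocation());
  return FileLineColLoc(); // null location when nothing matches
}
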
auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded( rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0], - sliceOpToTile->getResult(0) - .getType() - .cast() - .getShape()); + cast(sliceOpToTile->getResult(0).getType()).getShape()); assert(succeeded(maybeRankReduced) && "unexpected shape"); rewriter.replaceOp(sliceOpToTile, *maybeRankReduced); return tileAndFuseResult->tiledOps; @@ -482,7 +479,7 @@ // Replace the use in the tileableProducer before tiling: clone, replace and // then tile. - int64_t resultNumber = pUse->get().cast().getResultNumber(); + int64_t resultNumber = cast(pUse->get()).getResultNumber(); LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n"); // Gather destination tensors. @@ -516,10 +513,7 @@ // Replace the extract op. auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded( rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0], - sliceOpToTile->getResult(0) - .getType() - .cast() - .getShape()); + cast(sliceOpToTile->getResult(0).getType()).getShape()); assert(succeeded(maybeRankReduced) && "unexpected shape"); rewriter.replaceOp(sliceOpToTile, *maybeRankReduced); @@ -568,7 +562,7 @@ // TODO: Generalize to other type of ops. assert(!isa(use->getOwner()) && "Parallel insert slice is not a valid clone destination"); - unsigned resultNumber = use->get().cast().getResultNumber(); + unsigned resultNumber = cast(use->get()).getResultNumber(); LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n"); OpBuilder::InsertionGuard guard(rewriter); @@ -587,8 +581,7 @@ ArrayRef producerOps = state.getPayloadOps(getProducerOp()); // If nothing to fuse, propagate success. if (producerOps.empty()) { - results.set(getFusedOp().cast(), - SmallVector{}); + results.set(cast(getFusedOp()), SmallVector{}); return DiagnosedSilenceableFailure::success(); } ArrayRef containingOps = state.getPayloadOps(getContainingOp()); @@ -671,7 +664,7 @@ return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); } - results.set(getFusedOp().cast(), fusedOps); + results.set(cast(getFusedOp()), fusedOps); return DiagnosedSilenceableFailure::success(); } @@ -865,7 +858,7 @@ }; payloadOps.front()->walk(matchFun); - results.set(getResult().cast(), res); + results.set(cast(getResult()), res); return DiagnosedSilenceableFailure::success(); } @@ -901,7 +894,7 @@ DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne( LinalgOp target, transform::ApplyToEachResultList &results, TransformState &state) { - if (getLowSize().getType().isa()) { + if (isa(getLowSize().getType())) { if (target.hasDynamicShape()) { auto diag = emitSilenceableError() << "cannot compute parametric tile sizes for dynamically " @@ -923,7 +916,7 @@ spec->lowTileSize * spec->lowTripCount}), [&builder, this](int64_t value) { return builder.getIntegerAttr( - getLowSize().getType().cast().getType(), value); + cast(getLowSize().getType()).getType(), value); })); return DiagnosedSilenceableFailure::success(); } @@ -958,7 +951,7 @@ SmallVectorImpl &effects) { onlyReadsHandle(getTarget(), effects); producesHandle(getResults(), effects); - if (getLowSize().getType().isa()) + if (isa(getLowSize().getType())) onlyReadsPayload(effects); else modifiesPayload(effects); @@ -1006,7 +999,7 @@ ArrayRef targetOps = state.getPayloadOps(getTarget()); // If nothing to pack, propagate success. 
if (targetOps.empty()) { - transformResults.set(getPackedOp().cast(), {}); + transformResults.set(cast(getPackedOp()), {}); return DiagnosedSilenceableFailure::success(); } // Fail on multi-op handles. @@ -1036,7 +1029,7 @@ if (failed(maybeResult)) return emitDefiniteFailure("data tiling failed"); - transformResults.set(getPackedOp().cast(), + transformResults.set(cast(getPackedOp()), maybeResult->packedLinalgOp.getOperation()); return DiagnosedSilenceableFailure::success(); } @@ -1242,7 +1235,7 @@ } results.push_back(linalgOp); } - transformResults.set(getPackedOp().cast(), results); + transformResults.set(cast(getPackedOp()), results); return DiagnosedSilenceableFailure::success(); } @@ -1322,9 +1315,9 @@ ArrayRef linalgOps = state.getPayloadOps(getTargetLinalgOp()); // Step 1. If nothing to pack, propagate success. if (packOrUnpackOps.empty()) { - transformResults.set(getPackedOp().cast(), {}); - transformResults.set(getPackOp().cast(), {}); - transformResults.set(getUnPackOp().cast(), {}); + transformResults.set(cast(getPackedOp()), {}); + transformResults.set(cast(getPackOp()), {}); + transformResults.set(cast(getUnPackOp()), {}); return DiagnosedSilenceableFailure::success(); } @@ -1366,7 +1359,7 @@ if (unPackOp) { assert(!packOp && "packOp must be null on entry when unPackOp is not null"); OpOperand *packUse = linalgOp.getDpsInitOperand( - unPackOp.getSource().cast().getResultNumber()); + cast(unPackOp.getSource()).getResultNumber()); packOp = dyn_cast_or_null(packUse->get().getDefiningOp()); if (!packOp || !packOp.getResult().hasOneUse()) return emitSilenceableError() << "could not find matching pack op"; @@ -1400,14 +1393,14 @@ assert(succeeded(res) && "unexpected packTranspose failure"); // Step 4. Return results. - transformResults.set(getPackOp().cast(), {res->transposedPackOp}); - transformResults.set(getPackedOp().cast(), + transformResults.set(cast(getPackOp()), {res->transposedPackOp}); + transformResults.set(cast(getPackedOp()), {res->transposedLinalgOp}); if (unPackOp) { - transformResults.set(getUnPackOp().cast(), + transformResults.set(cast(getUnPackOp()), {res->transposedUnPackOp}); } else { - transformResults.set(getUnPackOp().cast(), {}); + transformResults.set(cast(getUnPackOp()), {}); } return DiagnosedSilenceableFailure::success(); @@ -1430,14 +1423,14 @@ SmallVector paddingValues; for (auto const &it : llvm::zip(getPaddingValues(), target->getOperandTypes())) { - auto attr = std::get<0>(it).dyn_cast(); + auto attr = dyn_cast(std::get<0>(it)); if (!attr) { emitOpError("expects padding values to be typed attributes"); return DiagnosedSilenceableFailure::definiteFailure(); } Type elementType = getElementTypeOrSelf(std::get<1>(it)); // Try to parse string attributes to obtain an attribute of element type. - if (auto stringAttr = attr.dyn_cast()) { + if (auto stringAttr = dyn_cast(attr)) { auto parsedAttr = dyn_cast_if_present( parseAttribute(stringAttr, getContext(), elementType, /*numRead=*/nullptr, /*isKnownNullTerminated=*/true)); @@ -1462,9 +1455,9 @@ // Extract the transpose vectors. 
SmallVector> transposePaddings; - for (Attribute transposeVector : getTransposePaddings().cast()) + for (Attribute transposeVector : cast(getTransposePaddings())) transposePaddings.push_back( - extractFromI64ArrayAttr(transposeVector.cast())); + extractFromI64ArrayAttr(cast(transposeVector))); TrackingListener listener(state, *this); IRRewriter rewriter(getContext(), &listener); @@ -1549,13 +1542,13 @@ return emitDefiniteFailure() << "could not build packing loop nest"; if (result->clonedLoopIvs.empty()) { - transformResults.set(getPackingLoop().cast(), + transformResults.set(cast(getPackingLoop()), result->hoistedPadOp.getOperation()); return DiagnosedSilenceableFailure::success(); } auto outerPackedLoop = scf::getForInductionVarOwner(result->clonedLoopIvs.front()); - transformResults.set(getPackingLoop().cast(), + transformResults.set(cast(getPackingLoop()), outerPackedLoop.getOperation()); return DiagnosedSilenceableFailure::success(); } @@ -1643,7 +1636,7 @@ if (mapping.size() > 1) return emitDefaultDefiniteFailure(target); - auto addressSpace = mapping[0].cast(); + auto addressSpace = cast(mapping[0]); if (addressSpace.getAddressSpace() == gpu::GPUDialect::getWorkgroupAddressSpace()) { @@ -1711,7 +1704,7 @@ rewriter.replaceOp(target, replacement->getResults()); replacements.push_back(replacement); } - transformResults.set(getReplacement().cast(), replacements); + transformResults.set(cast(getReplacement()), replacements); return DiagnosedSilenceableFailure::success(); } @@ -1828,7 +1821,7 @@ splitPoints.reserve(payload.size()); if (getDynamicSplitPoint()) { auto diag = DiagnosedSilenceableFailure::success(); - if (getDynamicSplitPoint().getType().isa()) { + if (isa(getDynamicSplitPoint().getType())) { splitPoints = llvm::to_vector(llvm::map_range( state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) { if (op->getNumResults() != 1 || @@ -1909,8 +1902,8 @@ return diag; } - results.set(getFirst().cast(), first); - results.set(getSecond().cast(), second); + results.set(cast(getFirst()), first); + results.set(cast(getSecond()), second); return DiagnosedSilenceableFailure::success(); } @@ -2212,12 +2205,12 @@ dynamicSizeProducers.reserve(getDynamicSizes().size()); paramSizes.reserve(getDynamicSizes().size()); for (Value transformValue : getDynamicSizes()) { - if (transformValue.getType().isa()) { + if (isa(transformValue.getType())) { dynamicSizeProducers.push_back({}); ArrayRef params = state.getParams(transformValue); paramSizes.push_back( llvm::to_vector(llvm::map_range(params, [](Attribute attr) { - return attr.cast().getValue().getSExtValue(); + return cast(attr).getValue().getSExtValue(); }))); if (paramSizes.back().size() != targets.size()) { @@ -2247,7 +2240,7 @@ for (Operation *op : dynamicSizeProducers.back()) { if (op->getNumResults() == 1 && - op->getResult(0).getType().isa()) + isa(op->getResult(0).getType())) continue; DiagnosedSilenceableFailure diag = @@ -2283,7 +2276,7 @@ for (OpFoldResult ofr : getMixedSizes()) { if (auto attr = ofr.dyn_cast()) { sizes.push_back(b.create( - getLoc(), attr.cast().getInt())); + getLoc(), cast(attr).getInt())); continue; } ArrayRef dynamicSizes = dynamicSizeProducers[dynamicIdx]; @@ -2320,9 +2313,9 @@ loops[en2.index()].push_back(en2.value()); } - transformResults.set(getTiledLinalgOp().cast(), tiled); + transformResults.set(cast(getTiledLinalgOp()), tiled); for (const auto &en : llvm::enumerate(loops)) - transformResults.set(getLoops()[en.index()].cast(), en.value()); + transformResults.set(cast(getLoops()[en.index()]), 
en.value()); return DiagnosedSilenceableFailure::success(); } @@ -2582,8 +2575,8 @@ tiledOps.push_back(tilingResult.tiledOp); } - transformResults.set(getForallOp().cast(), tileOps); - transformResults.set(getTiledOp().cast(), tiledOps); + transformResults.set(cast(getForallOp()), tileOps); + transformResults.set(cast(getTiledOp()), tiledOps); return DiagnosedSilenceableFailure::success(); } @@ -2678,7 +2671,7 @@ for (Operation *op : dynamicSizeProducers.back()) { if (op->getNumResults() == 1 && - op->getResult(0).getType().isa()) + isa(op->getResult(0).getType())) continue; DiagnosedSilenceableFailure diag = emitSilenceableError() << "expected sizes to be produced by ops " @@ -2712,7 +2705,7 @@ for (OpFoldResult ofr : getMixedSizes()) { if (auto attr = ofr.dyn_cast()) { sizes.push_back(b.create( - getLoc(), attr.cast().getInt())); + getLoc(), cast(attr).getInt())); } else { sizes.push_back( dynamicSizeProducers[dynamicIdx++][index]->getResult(0)); @@ -2737,9 +2730,9 @@ loops[en2.index()].push_back(en2.value()); } - transformResults.set(getTiledLinalgOp().cast(), tiled); + transformResults.set(cast(getTiledLinalgOp()), tiled); for (const auto &en : llvm::enumerate(loops)) - transformResults.set(getLoops()[en.index()].cast(), en.value()); + transformResults.set(cast(getLoops()[en.index()]), en.value()); return DiagnosedSilenceableFailure::success(); } @@ -2899,7 +2892,7 @@ for (OpFoldResult sz : getMixedVectorSizes()) { if (sz.is()) { auto attr = sz.get(); - vectorSizes.push_back(attr.cast().getInt()); + vectorSizes.push_back(cast(attr).getInt()); continue; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp @@ -64,20 +64,20 @@ if (genericOp.getNumDpsInits() != 1) return failure(); - auto outputType = genericOp.getResultTypes().front().dyn_cast(); + auto outputType = dyn_cast(genericOp.getResultTypes().front()); // Require the output types to be static given that we are generating // constants. if (!outputType || !outputType.hasStaticShape()) return failure(); if (!llvm::all_of(genericOp.getInputs(), [](Value input) { - return input.getType().isa(); + return isa(input.getType()); })) return failure(); // Make sure all element types are the same. auto getOperandElementType = [](Value value) { - return value.getType().cast().getElementType(); + return cast(value.getType()).getElementType(); }; if (!llvm::all_equal( llvm::map_range(genericOp->getOperands(), getOperandElementType))) @@ -138,7 +138,7 @@ // unify the following cases but they have lifetime as the MLIRContext. SmallVector intOutputValues; SmallVector fpOutputValues; - if (elementType.template isa()) + if (isa(elementType)) fpOutputValues.resize(numElements, APFloat(0.f)); else intOutputValues.resize(numElements); @@ -174,7 +174,7 @@ auto inputShapes = llvm::to_vector<4>( llvm::map_range(genericOp.getInputs(), [](Value value) { - return value.getType().cast().getShape(); + return cast(value.getType()).getShape(); })); // Given a `linearIndex`, remap it to a linear index to access linalg op @@ -205,7 +205,7 @@ } }; - bool isFloat = elementType.isa(); + bool isFloat = isa(elementType); if (isFloat) { SmallVector> inFpRanges; for (int i = 0; i < numInputs; ++i) @@ -282,7 +282,7 @@ // The yield op should return the block argument corresponds to the input. 
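Note what the patch deliberately leaves alone: OpFoldResult is a PointerUnion of Attribute and Value, and its member is/dyn_cast calls (the ofr.dyn_cast and sz.is tests in the getMixedSizes and getMixedVectorSizes hunks above) are outside the scope of this migration. Only the cast on the unwrapped Attribute moves to the free-function form. A hypothetical helper showing both styles side by side, with IntegerAttr inferred from the .getInt() calls:

#include <cstdint>
#include <optional>
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/OpDefinition.h"
using namespace mlir;

static std::optional<int64_t> getStaticSize(OpFoldResult ofr) {
  if (auto attr = ofr.dyn_cast<Attribute>())  // PointerUnion member: unchanged
    return cast<IntegerAttr>(attr).getInt();  // Attribute cast: free function
  return std::nullopt;                        // dynamic (Value) case
}
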
for (Value yieldVal : yieldOp.getValues()) { - auto yieldArg = yieldVal.dyn_cast(); + auto yieldArg = dyn_cast(yieldVal); if (!yieldArg || yieldArg.getOwner() != &body) return nullptr; if (yieldArg.getArgNumber() != 0) diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp @@ -29,7 +29,7 @@ } static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) { - bool isInt = x.getType().isa(); + bool isInt = isa(x.getType()); if (isInt) return builder.create(loc, x, y); return builder.create(loc, x, y); @@ -42,7 +42,7 @@ convertScalarToDtype(builder, loc, x, accType, /*isUnsignedCast=*/false); Value yConvert = convertScalarToDtype(builder, loc, y, accType, /*isUnsignedCast=*/false); - if (accType.isa()) + if (isa(accType)) return builder.create(loc, xConvert, yConvert); return builder.create(loc, xConvert, yConvert); } @@ -74,9 +74,9 @@ FailureOr> rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) { - auto inputType = convOp.getInputs()[0].getType().cast(); - auto filterType = convOp.getInputs()[1].getType().cast(); - auto outputType = convOp.getOutputs()[0].getType().cast(); + auto inputType = cast(convOp.getInputs()[0].getType()); + auto filterType = cast(convOp.getInputs()[1].getType()); + auto outputType = cast(convOp.getOutputs()[0].getType()); if (!filterType.hasStaticShape()) return rewriter.notifyMatchFailure( @@ -210,9 +210,9 @@ FailureOr> rewriteInIm2Col(RewriterBase &rewriter, linalg::DepthwiseConv2DNhwcHwcOp convOp) { - auto inputType = convOp.getInputs()[0].getType().cast(); - auto filterType = convOp.getInputs()[1].getType().cast(); - auto outputType = convOp.getOutputs()[0].getType().cast(); + auto inputType = cast(convOp.getInputs()[0].getType()); + auto filterType = cast(convOp.getInputs()[1].getType()); + auto outputType = cast(convOp.getOutputs()[0].getType()); if (!filterType.hasStaticShape()) return rewriter.notifyMatchFailure( @@ -230,7 +230,7 @@ Location loc = convOp.getLoc(); auto transposeOperand = [&](Value operand, ArrayRef indices) { - auto operandTensorType = operand.getType().cast(); + auto operandTensorType = cast(operand.getType()); auto nloops = indices.size(); ArrayRef inputShape = operandTensorType.getShape(); @@ -272,7 +272,7 @@ Value inputT = transposeOperand(input, {0, 3, 1, 2}); Value filterT = transposeOperand(filter, {2, 0, 1}); ArrayRef filterTShape = - filterT.getType().cast().getShape(); + cast(filterT.getType()).getShape(); ArrayRef outputShape = outputType.getShape(); int n = outputShape[0]; @@ -360,9 +360,9 @@ FailureOr> rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) { - auto inputType = convOp.getInputs()[0].getType().cast(); - auto filterType = convOp.getInputs()[1].getType().cast(); - auto outputType = convOp.getOutputs()[0].getType().cast(); + auto inputType = cast(convOp.getInputs()[0].getType()); + auto filterType = cast(convOp.getInputs()[1].getType()); + auto outputType = cast(convOp.getOutputs()[0].getType()); if (!filterType.hasStaticShape()) return rewriter.notifyMatchFailure( diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp @@ 
-66,12 +66,12 @@ Attribute constYieldedValue; // Is the yielded value a bbArg defined outside of the PadOp? bool outsideBbArg = - yieldedValue.isa() && - yieldedValue.cast().getOwner()->getParentOp() != + isa(yieldedValue) && + cast(yieldedValue).getOwner()->getParentOp() != padOp.getOperation(); // Is the yielded value an OpResult defined outside of the PadOp? bool outsideOpResult = - yieldedValue.isa() && + isa(yieldedValue) && yieldedValue.getDefiningOp()->getParentOp() != padOp.getOperation(); bool invariantYieldedValue = outsideBbArg || outsideOpResult; if (matchPattern(yieldedValue, m_Constant(&constYieldedValue))) { @@ -120,19 +120,19 @@ static SmallVector reifyOrComputeDynamicSizes(OpBuilder &b, Value value) { - auto tensorType = value.getType().cast(); + auto tensorType = cast(value.getType()); if (tensorType.hasStaticShape()) return {}; // Try to reify dynamic sizes. ReifiedRankedShapedTypeDims reifiedShape; - if (value.isa() && + if (isa(value) && succeeded(reifyResultShapes(b, value.getDefiningOp(), reifiedShape))) { SmallVector dynSizes; for (int64_t i = 0; i < tensorType.getRank(); ++i) { if (tensorType.isDynamicDim(i)) dynSizes.push_back( - reifiedShape[value.cast().getResultNumber()][i] + reifiedShape[cast(value).getResultNumber()][i] .get()); } return dynSizes; @@ -153,12 +153,12 @@ Value value, Attribute memorySpace = {}) { OpBuilder::InsertionGuard g(rewriter); - auto tensorType = value.getType().cast(); + auto tensorType = cast(value.getType()); // Create buffer allocation. - auto memrefType = bufferization::getMemRefTypeWithStaticIdentityLayout( - tensorType, memorySpace) - .cast(); + auto memrefType = + cast(bufferization::getMemRefTypeWithStaticIdentityLayout( + tensorType, memorySpace)); SmallVector dynamicSizes = reifyOrComputeDynamicSizes(rewriter, value); Value alloc = rewriter.create(loc, memrefType, dynamicSizes); @@ -206,7 +206,7 @@ RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp) { Location loc = fromElementsOp.getLoc(); RankedTensorType tensorType = - fromElementsOp.getType().cast(); + cast(fromElementsOp.getType()); auto shape = tensorType.getShape(); // Create tensor.empty. @@ -247,7 +247,7 @@ return failure(); Location loc = generateOp.getLoc(); - RankedTensorType tensorType = generateOp.getType().cast(); + RankedTensorType tensorType = cast(generateOp.getType()); // Create tensor.empty. 
auto emptyOp = @@ -339,7 +339,7 @@ llvm::map_range(value.getUses(), [](OpOperand &use) { return &use; })); OpBuilder::InsertionGuard g(rewriter); - if (auto bbArg = value.dyn_cast()) { + if (auto bbArg = dyn_cast(value)) { rewriter.setInsertionPointToStart(bbArg.getOwner()); } else { rewriter.setInsertionPointAfter(value.getDefiningOp()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp @@ -640,7 +640,7 @@ auto loc = genericOp.getLoc(); Value unPackDest = producerUnPackOp.getDest(); auto genericOutType = - genericOp.getDpsInitOperand(0)->get().getType().cast(); + cast(genericOp.getDpsInitOperand(0)->get().getType()); if (producerUnPackOp.getDestType() != genericOutType || !genericOutType.hasStaticShape()) { unPackDest = tensor::UnPackOp::createDestinationTensor( diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp @@ -132,12 +132,12 @@ static Value getZero(OpBuilder &b, Location loc, Type elementType) { assert(elementType.isIntOrIndexOrFloat() && "expected scalar type while computing zero value"); - if (elementType.isa()) + if (isa(elementType)) return b.create(loc, 0, elementType); if (elementType.isIndex()) return b.create(loc, 0); // Assume float. - auto floatType = elementType.cast(); + auto floatType = cast(elementType); return b.create( loc, APFloat::getZero(floatType.getFloatSemantics()), floatType); } @@ -179,7 +179,7 @@ if (resultNumber) { newInitValues.push_back( genericOp.getDpsInitOperand(*resultNumber)->get()); - OpResult result = genericOp.getResult(*resultNumber).cast(); + OpResult result = cast(genericOp.getResult(*resultNumber)); newResultTypes.push_back(result.getType()); peeledGenericOpIndexingMaps.push_back( genericOp.getIndexingMapMatchingResult(result)); @@ -231,7 +231,7 @@ })); for (auto resultNum : llvm::seq(origNumResults, peeledGenericOpNumResults)) { - OpResult result = peeledGenericOp.getResult(resultNum).cast(); + OpResult result = cast(peeledGenericOp.getResult(resultNum)); indexingMaps.push_back( peeledGenericOp.getIndexingMapMatchingResult(result)); } @@ -348,7 +348,7 @@ /// the peeled operation. SmallVector replacements; for (const auto &yieldValue : llvm::enumerate(yieldOp->getOperands())) { - OpResult opr = yieldValue.value().dyn_cast(); + OpResult opr = dyn_cast(yieldValue.value()); if (!opr || opr.getOwner() != peeledScalarOperation) replacements.push_back(residualGenericOp.getResult(yieldValue.index())); else diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -32,7 +32,7 @@ ValueRange inputs, Location loc) { assert(inputs.size() == 1); auto inputType = inputs[0].getType(); - if (inputType.isa()) + if (isa(inputType)) return nullptr; // A detensored value is converted back by creating a new tensor from its @@ -320,9 +320,9 @@ // * Add the argument to blockArgsToDetensor. // * Walk the use-def chain backwards to add each predecessor's // terminator-operands corresponding to currentItem to workList. 
- if (currentItem.dyn_cast()) { + if (dyn_cast(currentItem)) { BlockArgument currentItemBlockArgument = - currentItem.cast(); + cast(currentItem); Block *ownerBlock = currentItemBlockArgument.getOwner(); // Function arguments are not detensored/converted. diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -308,7 +308,7 @@ for (OpOperand *op : candidates) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointAfterValue(op->get()); - auto elemType = op->get().getType().cast().getElementType(); + auto elemType = cast(op->get().getType()).getElementType(); auto empty = rewriter.create( loc, tensor::createDimValues(rewriter, loc, op->get()), elemType); @@ -387,7 +387,7 @@ // Early return for memrefs with affine maps to represent that we will always // leave them unchanged. Type actualType = opOperand->get().getType(); - if (auto memref = actualType.dyn_cast()) { + if (auto memref = dyn_cast(actualType)) { if (!memref.getLayout().isIdentity()) return std::nullopt; } @@ -437,7 +437,7 @@ ArrayRef reassociation, Location loc, PatternRewriter &rewriter) const { // There are no results for memref outputs. - auto origResultType = origOutput.getType().cast(); + auto origResultType = cast(origOutput.getType()); if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) { unsigned rank = origResultType.getRank(); SmallVector offsets(rank, rewriter.getIndexAttr(0)); @@ -459,7 +459,7 @@ Value collapseValue(Value operand, ArrayRef targetShape, ArrayRef reassociation, Location loc, PatternRewriter &rewriter) const { - if (auto memrefType = operand.getType().dyn_cast()) { + if (auto memrefType = dyn_cast(operand.getType())) { if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) { FailureOr rankReducingExtract = memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand, @@ -478,7 +478,7 @@ return rewriter.create(loc, targetType, operand, reassociation); } - if (auto tensorType = operand.getType().dyn_cast()) { + if (auto tensorType = dyn_cast(operand.getType())) { if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) { FailureOr rankReducingExtract = tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand, @@ -502,7 +502,7 @@ PatternRewriter &rewriter) const override { // Skip the pattern if the op has any tensor with special encoding. if (llvm::any_of(genericOp->getOperandTypes(), [](Type type) { - auto tensorType = type.dyn_cast(); + auto tensorType = dyn_cast(type); return tensorType && tensorType.getEncoding() != nullptr; })) return failure(); @@ -607,11 +607,10 @@ if (!reassociation || reassociation->size() == static_cast(resultType.getRank())) return failure(); - auto rankReducedType = + auto rankReducedType = cast( tensor::ExtractSliceOp::inferCanonicalRankReducedResultType( reassociation->size(), sliceOp.getSourceType(), offsets, sizes, - strides) - .cast(); + strides)); Location loc = sliceOp.getLoc(); Value newSlice = rewriter.create( diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -87,7 +87,7 @@ // type. Producer must have full tensor semantics to avoid potential // aliasing between producer and consumer memrefs. 
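Pure predicates follow the same pattern: x.isa<T>() becomes isa<T>(x), with no cast needed because the result is only tested. (The Detensorize hunk above converts a dyn_cast whose result is discarded; an isa would express that check even more directly.) A representative reconstruction of the tensor-operand predicates seen in these hunks, with RankedTensorType as an assumed template argument:

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/STLExtras.h"
using namespace mlir;

static bool hasOnlyRankedTensorOperands(Operation *op) {
  return llvm::all_of(op->getOperandTypes(),
                      [](Type type) { return isa<RankedTensorType>(type); });
}
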
if (!producer.hasTensorSemantics() || - !fusedOperand->get().getType().isa()) + !isa(fusedOperand->get().getType())) return false; // Verify that @@ -232,14 +232,14 @@ // forward the yield operand. auto producerYieldOp = cast(producerBlock.getTerminator()); unsigned producerResultNumber = - fusedOperand->get().cast().getResultNumber(); + cast(fusedOperand->get()).getResultNumber(); Value replacement = mapper.lookupOrDefault(producerYieldOp.getOperand(producerResultNumber)); // Sanity checks, if replacement is not already in the mapper then it must be // produced outside. if (replacement == producerYieldOp.getOperand(producerResultNumber)) { - if (auto bb = replacement.dyn_cast()) + if (auto bb = dyn_cast(replacement)) assert(bb.getOwner() != &producerBlock && "yielded block argument must have been mapped"); else @@ -278,7 +278,7 @@ OpOperand *fusedOperand) { assert(areElementwiseOpsFusable(fusedOperand) && "expected elementwise operation pre-conditions to pass"); - auto producerResult = fusedOperand->get().cast(); + auto producerResult = cast(fusedOperand->get()); auto producer = cast(producerResult.getOwner()); auto consumer = cast(fusedOperand->getOwner()); // TODO: allow fusing the producer of an output operand. @@ -357,7 +357,7 @@ fusedOutputOperands.push_back(opOperand->get()); fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand)); Type resultType = opOperand->get().getType(); - if (!resultType.isa()) + if (!isa(resultType)) fusedResultTypes.push_back(resultType); } @@ -512,7 +512,7 @@ return genericOp.hasTensorSemantics() && llvm::all_of(genericOp.getIndexingMaps().getValue(), [](Attribute attr) { - return attr.cast() + return cast(attr) .getValue() .isProjectedPermutation(); }) && @@ -776,7 +776,7 @@ continue; } if (auto opOperandType = - opOperand->get().getType().dyn_cast()) { + dyn_cast(opOperand->get().getType())) { AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand); RankedTensorType expandedOperandType = getExpandedType(opOperandType, indexingMap, expansionInfo); @@ -805,7 +805,7 @@ SmallVector outputs; for (OpOperand *opOperand : genericOp.getDpsInitOperands()) { AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand); - auto opOperandType = opOperand->get().getType().cast(); + auto opOperandType = cast(opOperand->get().getType()); RankedTensorType expandedOutputType = getExpandedType(opOperandType, indexingMap, expansionInfo); if (expandedOutputType != opOperand->get().getType()) { @@ -921,7 +921,7 @@ LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp, PatternRewriter &rewriter) const override { // Fold only if all constraints of fusing with reshape by expansion are met. - auto producerResult = reshapeOp.getSrc().dyn_cast(); + auto producerResult = dyn_cast(reshapeOp.getSrc()); if (!producerResult) { return rewriter.notifyMatchFailure(reshapeOp, "source not produced by an operation"); @@ -959,8 +959,9 @@ // same type as the returns of the original generic op, the consumer reshape // op can be replaced by the source of the collapse_shape op that defines // the replacement. 
- Value reshapeReplacement = (*replacementValues) - [reshapeOp.getSrc().cast().getResultNumber()]; + Value reshapeReplacement = + (*replacementValues)[cast(reshapeOp.getSrc()) + .getResultNumber()]; if (auto collapseOp = reshapeReplacement.getDefiningOp()) { reshapeReplacement = collapseOp.getSrc(); @@ -1438,7 +1439,7 @@ .createLoopRanges(rewriter, genericOp.getLoc()); auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) { if (auto attr = ofr.dyn_cast()) - return attr.cast().getInt() == value; + return cast(attr).getInt() == value; llvm::APInt actual; return matchPattern(ofr.get(), m_ConstantInt(&actual)) && actual.getSExtValue() == value; @@ -1512,8 +1513,8 @@ Value collapsedOpResult = collapsedGenericOp->getResult(originalResult.index()); auto originalResultType = - originalResult.value().getType().cast(); - auto collapsedOpResultType = collapsedOpResult.getType().cast(); + cast(originalResult.value().getType()); + auto collapsedOpResultType = cast(collapsedOpResult.getType()); if (collapsedOpResultType.getRank() != originalResultType.getRank()) { AffineMap indexingMap = genericOp.getIndexingMapMatchingResult(originalResult.value()); @@ -1655,7 +1656,7 @@ return false; }; - auto resultValue = opOperand->get().dyn_cast(); + auto resultValue = dyn_cast(opOperand->get()); if (!def || !resultValue || !isScalarOrSplatConstantOp(def)) continue; @@ -1740,7 +1741,7 @@ for (OpOperand *opOperand : op.getDpsInitOperands()) { if (!op.payloadUsesValueFromOperand(opOperand)) { Value operandVal = opOperand->get(); - auto operandType = operandVal.getType().dyn_cast(); + auto operandType = dyn_cast(operandVal.getType()); if (!operandType) continue; @@ -1794,7 +1795,7 @@ fillFound = true; Value fillVal = fillOp.value(); auto resultType = - fillOp.result().getType().cast().getElementType(); + cast(fillOp.result().getType()).getElementType(); Value convertedVal = convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType, /*isUnsignedCast =*/false); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp @@ -28,7 +28,7 @@ // TODO: The conversion pattern can be made to work for `any_of` here, but // it's more complex as it requires tracking which operands are scalars. return llvm::all_of(op->getOperandTypes(), - [](Type type) { return type.isa(); }); + [](Type type) { return isa(type); }); } /// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over @@ -67,7 +67,7 @@ // Extract static / dynamic shape mix from the first operand. 
Value firstOperand = operands.front(); - auto rankedTensorType = t.cast(); + auto rankedTensorType = cast(t); auto staticShape = llvm::to_vector<4>(rankedTensorType.getShape()); auto dynamicShape = linalg::createDynamicDimensions(b, loc, firstOperand); @@ -87,7 +87,7 @@ return rewriter.notifyMatchFailure( op, "requires elementwise op on ranked tensors"); - auto rank = op->getResult(0).getType().cast().getRank(); + auto rank = cast(op->getResult(0).getType()).getRank(); SmallVector indexingMaps( op->getNumResults() + op->getNumOperands(), rewriter.getMultiDimIdentityMap(rank)); @@ -104,7 +104,7 @@ [&](OpBuilder &builder, Location loc, ValueRange regionArgs) { auto resultTypes = llvm::to_vector<6>( llvm::map_range(op->getResultTypes(), [](Type type) { - return type.cast().getElementType(); + return cast(type).getElementType(); })); auto *scalarOp = builder.create(loc, op->getName().getIdentifier(), diff --git a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp @@ -89,7 +89,7 @@ Location loc = genericOp.getLoc(); SmallVector newResultTypes; for (Value v : newOutputOperands) - if (v.getType().isa()) + if (isa(v.getType())) newResultTypes.push_back(v.getType()); auto newOp = rewriter.create( loc, newResultTypes, newInputOperands, newOutputOperands, diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp @@ -86,12 +86,12 @@ // result of the generic op. The low pad values are the offsets, the size of // the source is the size of the slice. // TODO: This insert/extract could be potentially made a utility method. - unsigned resultNumber = source.cast().getResultNumber(); + unsigned resultNumber = cast(source).getResultNumber(); SmallVector offsets = padOp.getMixedLowPad(); SmallVector sizes; sizes.reserve(offsets.size()); - for (const auto &shape : llvm::enumerate( - source.getType().cast().getShape())) { + for (const auto &shape : + llvm::enumerate(cast(source.getType()).getShape())) { if (ShapedType::isDynamic(shape.value())) { sizes.push_back( rewriter.create(loc, source, shape.index()) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -151,7 +151,7 @@ SmallVector resultTypes; resultTypes.reserve(producer->getNumResults()); for (OpOperand *operand : producer.getDpsInitOperands()) { - auto tensorType = operand->get().getType().dyn_cast(); + auto tensorType = dyn_cast(operand->get().getType()); if (!tensorType) continue; unsigned rank = tensorType.getRank(); @@ -210,20 +210,20 @@ // dependence tracking since the dependence tracking is similar to what is done // w.r.t to buffers. 
static void getProducerOfTensor(Value tensor, OpResult &opResult) { - if (!tensor.getType().isa()) + if (!isa(tensor.getType())) return; while (true) { LLVM_DEBUG(llvm::dbgs() << "\ngetProducerOfTensor: " << tensor); if (auto linalgOp = tensor.getDefiningOp()) { - opResult = tensor.cast(); + opResult = cast(tensor); return; } if (auto sliceOp = tensor.getDefiningOp()) { tensor = sliceOp.getSource(); continue; } - if (auto blockArg = tensor.dyn_cast()) { + if (auto blockArg = dyn_cast(tensor)) { if (auto forOp = blockArg.getDefiningOp()) { tensor = *(forOp.getIterOperands().begin() + blockArg.getArgNumber()); continue; diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -227,7 +227,7 @@ return {}; bbArgs.push_back(bbArg); OpOperand *iterArg = &tileLoop.getOpOperandForRegionIterArg(bbArg); - bbArg = iterArg->get().dyn_cast(); + bbArg = dyn_cast(iterArg->get()); } // Reverse the block arguments to order them from outer to inner. @@ -358,13 +358,13 @@ // Check if the producer is a LinalgOp possibly passed by iteration argument. OpOperand *iterArg = nullptr; - auto producerResult = sliceOp.getSource().dyn_cast(); - if (auto bbArg = sliceOp.getSource().dyn_cast()) { + auto producerResult = dyn_cast(sliceOp.getSource()); + if (auto bbArg = dyn_cast(sliceOp.getSource())) { iterArg = getTiedIterArg(bbArg); // Check the iteration argument may be used to pass in the producer output. if (!iterArg || hasOtherUses(bbArg, sliceOp)) return failure(); - producerResult = iterArg->get().dyn_cast(); + producerResult = dyn_cast(iterArg->get()); } if (!producerResult || !isa(producerResult.getOwner())) return failure(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp @@ -549,7 +549,7 @@ int paddedRank = paddedTensorType.getRank(); // Step 0. Populate bvm with opToHoist.getSource if relevant. - BlockArgument bbArg = opToHoist.getSource().dyn_cast(); + BlockArgument bbArg = dyn_cast(opToHoist.getSource()); while (bbArg) { auto forOp = dyn_cast(bbArg.getOwner()->getParentOp()); if (!forOp) @@ -558,7 +558,7 @@ break; OpOperand &operand = forOp.getOpOperandForRegionIterArg(bbArg); bvm.map(bbArg, operand.get()); - bbArg = operand.get().dyn_cast(); + bbArg = dyn_cast(operand.get()); } // Step 1. iteratively clone loops and push `hoistedPackedTensor`. 
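Where the old style chained casts fluently, the free-function style nests instead; the rank-reducing slice hunks above and the HoistPadding hunk below reflow such chains. A minimal sketch of the transformation, assuming RankedTensorType as a representative template argument:

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
using namespace mlir;

// Before: value.getType().cast<RankedTensorType>().getShape();
// After: the cast wraps the expression rather than trailing it.
static ArrayRef<int64_t> getStaticShape(Value value) {
  return cast<RankedTensorType>(value.getType()).getShape();
}
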
@@ -753,9 +753,8 @@ if (!destOp) break; LLVM_DEBUG(DBGS() << "--step dest op: " << destOp << "\n"); - source = - destOp.getDpsInitOperand(source.cast().getResultNumber()) - ->get(); + source = destOp.getDpsInitOperand(cast(source).getResultNumber()) + ->get(); } LLVM_DEBUG(DBGS() << "--final source: " << source << "\n"); LLVM_DEBUG(DBGS() << "--expected source: " << expectedSource << "\n"); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -86,7 +86,7 @@ [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); }); func.walk([&](vector::TransferReadOp transferRead) { - if (!transferRead.getShapedType().isa()) + if (!isa(transferRead.getShapedType())) return WalkResult::advance(); LLVM_DEBUG(DBGS() << "Candidate for hoisting: " diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -162,7 +162,7 @@ SmallVector, 8> indexing; SmallVector outputBuffers; for (OpOperand *outputOperand : linalgOp.getDpsInitOperands()) { - if (!outputOperand->get().getType().isa()) + if (!isa(outputOperand->get().getType())) continue; indexing.push_back(makeCanonicalAffineApplies( b, loc, linalgOp.getMatchingIndexingMap(outputOperand), @@ -242,7 +242,7 @@ return failure(); // The induction variable is a block argument of the entry block of the // loop operation. - BlockArgument ivVal = iv.dyn_cast(); + BlockArgument ivVal = dyn_cast(iv); if (!ivVal) return failure(); loopSet.insert(ivVal.getOwner()->getParentOp()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp @@ -44,9 +44,9 @@ auto result = operation->getResult(0); - auto kernelTy = kernel.getType().dyn_cast(); - auto initTy = init.getType().dyn_cast(); - auto resultTy = result.getType().template dyn_cast(); + auto kernelTy = dyn_cast(kernel.getType()); + auto initTy = dyn_cast(init.getType()); + auto resultTy = dyn_cast(result.getType()); if (!kernelTy || !initTy || !resultTy) return failure(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -292,9 +292,9 @@ }) .Case([&](ComplexType t) { Value tmp; - if (auto et = t.getElementType().dyn_cast()) + if (auto et = dyn_cast(t.getElementType())) tmp = b.create(FloatAttr::get(et, 0.0)); - else if (auto et = t.getElementType().cast()) + else if (auto et = cast(t.getElementType())) tmp = b.create(IntegerAttr::get(et, 0)); return b.create(t, tmp, tmp); }) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp @@ -93,7 +93,7 @@ {iterationSpace[dimension].offset, iterationSpace[dimension].size, minSplitPoint}); if (auto attr = remainingSize.dyn_cast()) { - if (attr.cast().getValue().isZero()) + if (cast(attr).getValue().isZero()) return {op, TilingInterface()}; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -113,7 +113,7 @@
     }
     Type newType = RankedTensorType::get(
         newShape,
-        operand->get().getType().cast<RankedTensorType>().getElementType());
+        cast<RankedTensorType>(operand->get().getType()).getElementType());
     Value newInput = b.create<tensor::ExpandShapeOp>(
         loc, newType, operand->get(), reassociation);
     newInputs.push_back(newInput);
@@ -309,7 +309,7 @@
   fillOps.reserve(op.getNumDpsInits());
   for (auto it : llvm::zip(op.getDpsInitOperands(), neutralElements)) {
     Value rankedTensor = std::get<0>(it)->get();
-    auto t = rankedTensor.getType().cast<RankedTensorType>();
+    auto t = cast<RankedTensorType>(rankedTensor.getType());
     RankedTensorType newT = RankedTensorType::Builder(t).insertDim(
         reductionDimSize / splitFactor, insertSplitDimension);
     SmallVector<Value> dims =
@@ -383,7 +383,7 @@
                            combinerOps)) {
     Value reindexedOutput = std::get<0>(it);
     Value originalOutput = std::get<1>(it)->get();
-    auto originalOutputType = originalOutput.getType().cast<RankedTensorType>();
+    auto originalOutputType = cast<RankedTensorType>(originalOutput.getType());
     Operation *combinerOp = std::get<2>(it);
     AffineMap map = b.getMultiDimIdentityMap(originalOutputType.getRank() + 1);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp
@@ -65,7 +65,7 @@
 findHoistableMatchingExtractSlice(RewriterBase &rewriter,
                                   tensor::InsertSliceOp insertSliceOp,
                                   BlockArgument srcTensor) {
-  assert(srcTensor.getType().isa<RankedTensorType>() && "not a ranked tensor");
+  assert(isa<RankedTensorType>(srcTensor.getType()) && "not a ranked tensor");
 
   auto forOp = cast<scf::ForOp>(srcTensor.getOwner()->getParentOp());
 
@@ -92,7 +92,7 @@
     // Skip insert_slice whose vector is defined within the loop: we need to
     // hoist that definition first otherwise dominance violations trigger.
-    if (!extractSliceOp.getSource().isa<BlockArgument>() &&
+    if (!isa<BlockArgument>(extractSliceOp.getSource()) &&
         !forOp.isDefinedOutsideOfLoop(extractSliceOp.getSource())) {
       LLVM_DEBUG(DBGS() << "------transfer_read vector is loop-dependent\n");
       continue;
     }
@@ -119,7 +119,7 @@
 findHoistableMatchingTransferRead(RewriterBase &rewriter,
                                   vector::TransferWriteOp transferWriteOp,
                                   BlockArgument srcTensor) {
-  if (!srcTensor.getType().isa<RankedTensorType>())
+  if (!isa<RankedTensorType>(srcTensor.getType()))
     return failure();
 
   auto forOp = cast<scf::ForOp>(srcTensor.getOwner()->getParentOp());
 
@@ -152,7 +152,7 @@
     // transfer_read may be of a vector that is defined within the loop: we
     // traverse it by virtue of bypassing disjoint subset operations rooted at
     // a bbArg and yielding a matching yield.
-    if (!read.getSource().isa<BlockArgument>() &&
+    if (!isa<BlockArgument>(read.getSource()) &&
         !forOp.isDefinedOutsideOfLoop(read.getSource())) {
       LLVM_DEBUG(DBGS() << "------transfer_read vector appears loop "
                            "dependent but will be tested for disjointness as "
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -49,7 +49,7 @@
   if (!v)
     return false;
   if (auto attr = v.dyn_cast<Attribute>()) {
-    IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
+    IntegerAttr intAttr = dyn_cast<IntegerAttr>(attr);
     return intAttr && intAttr.getValue().isZero();
   }
   if (auto cst = v.get<Value>().getDefiningOp<arith::ConstantIndexOp>())
@@ -105,7 +105,7 @@
 static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
                                          OpFoldResult value) {
   if (auto attr = value.dyn_cast<Attribute>()) {
-    assert(attr.cast<IntegerAttr>().getValue().isStrictlyPositive() &&
+    assert(cast<IntegerAttr>(attr).getValue().isStrictlyPositive() &&
            "expected strictly positive tile size and divisor");
     return;
   }
@@ -587,8 +587,8 @@
   SmallVector<Operation *> loops;
   loops.reserve(ivs.size());
   for (auto iv : ivs) {
-    if (iv.isa<BlockArgument>()) {
-      loops.push_back(iv.cast<BlockArgument>().getOwner()->getParentOp());
+    if (isa<BlockArgument>(iv)) {
+      loops.push_back(cast<BlockArgument>(iv).getOwner()->getParentOp());
       assert(loops.back() && "no owner found for induction variable!");
     } else {
       // TODO: Instead of doing this, try to recover the ops used instead of the
@@ -712,7 +712,7 @@
     outOffsets[reductionDim] = forallOp.getInductionVars().front();
     // TODO: use SubsetExtractOpInterface once it is available.
     tiledDpsInitOperands.push_back(b.create<tensor::ExtractSliceOp>(
-        loc, initOperand->get().getType().cast<RankedTensorType>(),
+        loc, cast<RankedTensorType>(initOperand->get().getType()),
         destBbArgs[destNum], outOffsets, sizes, strides));
   }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -365,8 +365,7 @@
     // Then create a new reduction that only reduce the newly added dimension
     // from the previous op.
-    int64_t intermRank =
-        partialReduce[0].getType().cast<ShapedType>().getRank();
+    int64_t intermRank = cast<ShapedType>(partialReduce[0].getType()).getRank();
     AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
     SmallVector<utils::IteratorType> reductionIteratorTypes;
     SmallVector<AffineExpr> exprs;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -89,7 +89,7 @@
   // Follow the use-def chain if `currOpOperand` is defined by a LinalgOp.
   OpOperand *currOpOperand = opOperand;
   while (auto linalgOp = currOpOperand->get().getDefiningOp<LinalgOp>()) {
-    OpResult result = currOpOperand->get().cast<OpResult>();
+    OpResult result = cast<OpResult>(currOpOperand->get());
     currOpOperand = linalgOp.getDpsInitOperand(result.getResultNumber());
   }
 
@@ -139,7 +139,7 @@
     // If the size is an attribute add it directly to `paddedShape`.
     if (en.value().is<Attribute>()) {
       paddedShape[shapeIdx++] =
-          en.value().get<Attribute>().dyn_cast<IntegerAttr>().getInt();
+          dyn_cast<IntegerAttr>(en.value().get<Attribute>()).getInt();
       LLVM_DEBUG(
           DBGS() << "------dim is an attr, add it to padded shape, SKIP\n");
       continue;
@@ -238,7 +238,7 @@
   for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
     Value paddedResult = en.value();
     int64_t resultNumber = en.index();
-    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
+    int64_t rank = cast<RankedTensorType>(paddedResult.getType()).getRank();
     SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
     SmallVector<OpFoldResult> sizes;
     SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
@@ -482,7 +482,7 @@
                                             tensor::PackOp packOp) {
   // 1. Filter out NYI cases.
   auto packedTensorType =
-      packOp->getResultTypes().front().cast<RankedTensorType>();
+      cast<RankedTensorType>(packOp->getResultTypes().front());
   if (!packedTensorType.hasStaticShape()) {
     return rewriter.notifyMatchFailure(
         packOp,
@@ -628,7 +628,7 @@
   int64_t packedRank = packedTensorType.getRank();
   OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
-  auto destTensorType = unPackOp.getDest().getType().cast<RankedTensorType>();
+  auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
   if (unPackOp.isLikeUnPad()) {
     // This unpack is just a plain unpad.
     // Just extract the slice from the higher ranked tensor.
@@ -878,7 +878,7 @@
   // Sanity check of the expected transposed tensor type.
   auto tensorType = permuteShape(
-      opOperand.get().getType().cast<RankedTensorType>(), permutation);
+      cast<RankedTensorType>(opOperand.get().getType()), permutation);
   (void)tensorType;
   assert(tensorType == transposedValue.getType() &&
          "expected tensor type mismatch");
@@ -1039,8 +1039,8 @@
 PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp,
                                             PatternRewriter &rewriter) const {
-  auto inputShapedType = padOp.getSource().getType().cast<ShapedType>();
-  auto resultShapedType = padOp.getResult().getType().cast<ShapedType>();
+  auto inputShapedType = cast<ShapedType>(padOp.getSource().getType());
+  auto resultShapedType = cast<ShapedType>(padOp.getResult().getType());
 
   // Bail on non-static shapes.
   if (!inputShapedType.hasStaticShape())
@@ -1057,7 +1057,7 @@
   Operation *definingOp = padValue.getDefiningOp();
   if (definingOp && definingOp->getBlock() == &block)
     return failure();
-  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
+  if (!definingOp && cast<BlockArgument>(padValue).getOwner() == &block)
     return failure();
 
   // Create tensor with the padded shape
@@ -1123,7 +1123,7 @@
       return val;
     return rewriter
         .create<arith::ConstantIndexOp>(
-            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
+            padOp.getLoc(), cast<IntegerAttr>(ofr.get<Attribute>()).getInt())
         .getResult();
   };
@@ -1503,9 +1503,9 @@
   Value kernel = convOp.getInputs().back();
   Value output = convOp.getOutputs().front();
 
-  auto inputType = input.getType().dyn_cast<RankedTensorType>();
-  auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
-  auto outputType = output.getType().dyn_cast<RankedTensorType>();
+  auto inputType = dyn_cast<RankedTensorType>(input.getType());
+  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
+  auto outputType = dyn_cast<RankedTensorType>(output.getType());
 
   auto kernelShape = kernelType.getShape();
   auto outputShape = outputType.getShape();
@@ -1627,9 +1627,9 @@
   Value kernel = convOp.getInputs().back();
   Value output = convOp.getOutputs().front();
 
-  auto inputType = input.getType().dyn_cast<RankedTensorType>();
-  auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
-  auto outputType = output.getType().dyn_cast<RankedTensorType>();
+  auto inputType = dyn_cast<RankedTensorType>(input.getType());
+  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
+  auto outputType = dyn_cast<RankedTensorType>(output.getType());
 
   auto kernelShape = kernelType.getShape();
   auto outputShape = outputType.getShape();
@@ -1695,9 +1695,9 @@
   Value kernel = convOp.getInputs().back();
   Value output = convOp.getOutputs().front();
 
-  auto inputType = input.getType().dyn_cast<RankedTensorType>();
-  auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
-  auto outputType = output.getType().dyn_cast<RankedTensorType>();
+  auto inputType = dyn_cast<RankedTensorType>(input.getType());
+  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
+  auto outputType = dyn_cast<RankedTensorType>(output.getType());
 
   auto kernelShape = kernelType.getShape();
   auto outputShape = outputType.getShape();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -563,7 +563,7 @@
         loc, value, outputOperand->get(), indices, writeMap);
   } else {
     // 0-d case is still special: do not invert the reindexing writeMap.
-    if (!value.getType().isa<VectorType>())
+    if (!isa<VectorType>(value.getType()))
       value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
     assert(value.getType() == vectorType && "incorrect type");
     write = rewriter.create<vector::TransferWriteOp>(
@@ -864,7 +864,7 @@
       targetShape.back() == 1)
     return VectorMemoryAccessKind::Gather;
 
-  auto inputShape = extractOp.getTensor().getType().cast<ShapedType>();
+  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
 
   // 2. Assume that it's a gather load when reading _from_ a tensor for which
   // the trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
@@ -1024,8 +1024,8 @@
                                 const IRMapping &bvm) {
   Value reduceVec = bvm.lookup(reduceValue);
   Value outputVec = bvm.lookup(initialValue);
-  auto reduceType = reduceVec.getType().dyn_cast<VectorType>();
-  auto outputType = outputVec.getType().dyn_cast<VectorType>();
+  auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
+  auto outputType = dyn_cast<VectorType>(outputVec.getType());
   // Reduce only if needed as the value may already have been reduce for
   // contraction vectorization.
   if (!reduceType ||
@@ -1082,7 +1082,7 @@
   // 4 . Check if the operation is a reduction.
   SmallVector<std::pair<Value, Value>> reductionOperands;
   for (Value operand : op->getOperands()) {
-    auto blockArg = operand.dyn_cast<BlockArgument>();
+    auto blockArg = dyn_cast<BlockArgument>(operand);
     if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
         blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
       continue;
@@ -1107,7 +1107,7 @@
   // a. first get the first max ranked shape.
   SmallVector<int64_t, 4> firstMaxRankedShape;
   for (Value operand : op->getOperands()) {
-    auto vt = bvm.lookup(operand).getType().dyn_cast<VectorType>();
+    auto vt = dyn_cast<VectorType>(bvm.lookup(operand).getType());
     if (vt && firstMaxRankedShape.size() < vt.getShape().size())
       firstMaxRankedShape.assign(vt.getShape().begin(), vt.getShape().end());
   }
@@ -1230,7 +1230,7 @@
 
     // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
     // TODO: remove this.
-    if (readValue.getType().cast<VectorType>().getRank() == 0)
+    if (cast<VectorType>(readValue.getType()).getRank() == 0)
       readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
 
     LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
@@ -1528,8 +1528,8 @@
 
 LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
                                           memref::CopyOp copyOp) {
-  auto srcType = copyOp.getSource().getType().cast<MemRefType>();
-  auto dstType = copyOp.getTarget().getType().cast<MemRefType>();
+  auto srcType = cast<MemRefType>(copyOp.getSource().getType());
+  auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
   if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
     return failure();
@@ -1549,7 +1549,7 @@
   Value readValue = rewriter.create<vector::TransferReadOp>(
       loc, readType, copyOp.getSource(), indices,
       rewriter.getMultiDimIdentityMap(srcType.getRank()));
-  if (readValue.getType().cast<VectorType>().getRank() == 0) {
+  if (cast<VectorType>(readValue.getType()).getRank() == 0) {
     readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
     readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
   }
@@ -1566,7 +1566,7 @@
 
 /// Helper function that retrieves the value of an IntegerAttr.
 static int64_t getIntFromAttr(Attribute attr) {
-  return attr.cast<IntegerAttr>().getInt();
+  return cast<IntegerAttr>(attr).getInt();
 }
 
 /// Given an ArrayRef of OpFoldResults, return a vector of Values.
@@ -1836,8 +1836,8 @@
   if (hasSameTensorSize(castOp.getSource(), afterTrimming))
     return true;
 
-  auto t1 = beforePadding.getType().dyn_cast<RankedTensorType>();
-  auto t2 = afterTrimming.getType().dyn_cast<RankedTensorType>();
+  auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
+  auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
   // Only RankedTensorType supported.
   if (!t1 || !t2)
     return false;
@@ -1946,7 +1946,7 @@
     if (!padValue)
       return failure();
     // Dynamic shapes not supported.
-    if (!padOp.getResult().getType().cast<ShapedType>().hasStaticShape())
+    if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
       return failure();
     // Pad result not used as destination.
     if (insertOp.getDest() == padOp.getResult())
@@ -2074,7 +2074,7 @@
   memref::CopyOp copyOp;
   for (auto &u : subView.getUses()) {
     if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
-      assert(newCopyOp.getTarget().getType().isa<MemRefType>());
+      assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
       if (newCopyOp.getTarget() != subView)
         continue;
       if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
@@ -2091,7 +2091,7 @@
   FillOp maybeFillOp;
   for (auto &u : viewOrAlloc.getUses()) {
     if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
-      assert(newFillOp.output().getType().isa<MemRefType>());
+      assert(isa<MemRefType>(newFillOp.output().getType()));
       if (newFillOp.output() != viewOrAlloc)
         continue;
       if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
@@ -2162,7 +2162,7 @@
     return rewriter.notifyMatchFailure(xferOp, "no copy found");
 
   // `out` is the subview copied into that we replace.
-  assert(copyOp.getTarget().getType().isa<MemRefType>());
+  assert(isa<MemRefType>(copyOp.getTarget().getType()));
   Value out = copyOp.getTarget();
 
   // Forward vector.transfer into copy.
@@ -2204,7 +2204,7 @@
 namespace {
 bool isCastOfBlockArgument(Operation *op) {
   return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
-         op->getOperand(0).isa<BlockArgument>();
+         isa<BlockArgument>(op->getOperand(0));
 }
 
 bool isSupportedPoolKind(vector::CombiningKind kind) {
@@ -2268,9 +2268,9 @@
     lhsShaped = linalgOp.getDpsInputOperand(0)->get();
     rhsShaped = linalgOp.getDpsInputOperand(1)->get();
     resShaped = linalgOp.getDpsInitOperand(0)->get();
-    lhsShapedType = lhsShaped.getType().dyn_cast<ShapedType>();
-    rhsShapedType = rhsShaped.getType().dyn_cast<ShapedType>();
-    resShapedType = resShaped.getType().dyn_cast<ShapedType>();
+    lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
+    rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
+    resShapedType = dyn_cast<ShapedType>(resShaped.getType());
     if (!lhsShapedType || !rhsShapedType || !resShapedType)
       return;
     // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
@@ -2717,8 +2717,8 @@
   /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to MulAcc
   Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
                                      Value lhs, Value rhs, Value res) {
-    auto rhsTy = rhs.getType().cast<ShapedType>();
-    auto resTy = res.getType().cast<ShapedType>();
+    auto rhsTy = cast<ShapedType>(rhs.getType());
+    auto resTy = cast<ShapedType>(res.getType());
 
     // TODO(suderman): Change this to use a vector.ima intrinsic.
     lhs = promote(rewriter, loc, lhs, resTy);
@@ -2730,7 +2730,7 @@
     if (!lhs || !rhs)
       return nullptr;
 
-    if (resTy.getElementType().isa<FloatType>())
+    if (isa<FloatType>(resTy.getElementType()))
       return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
 
     auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
@@ -2863,15 +2863,14 @@
   // Otherwise, check for one or zero `ext` predecessor. The `ext` operands
   // must be block arguments or extension of block arguments.
   bool setOperKind(Operation *reduceOp) {
-    int numBlockArguments =
-        llvm::count_if(reduceOp->getOperands(),
-                       [](Value v) { return v.isa<BlockArgument>(); });
+    int numBlockArguments = llvm::count_if(
+        reduceOp->getOperands(), [](Value v) { return isa<BlockArgument>(v); });
     switch (numBlockArguments) {
     case 1: {
       // Will be convolution if feeder is a MulOp.
       // Otherwise, if it can be pooling.
       auto feedValIt = llvm::find_if(reduceOp->getOperands(), [](Value v) {
-        return !v.isa<BlockArgument>();
+        return !isa<BlockArgument>(v);
       });
       Operation *feedOp = (*feedValIt).getDefiningOp();
       if (isCastOfBlockArgument(feedOp)) {
@@ -2880,7 +2879,7 @@
         poolExtOp = feedOp->getName().getIdentifier();
       } else if (!(isa<arith::MulIOp, arith::MulFOp>(feedOp) &&
                    llvm::all_of(feedOp->getOperands(), [](Value v) {
-                     if (v.isa<BlockArgument>())
+                     if (isa<BlockArgument>(v))
                        return true;
                      if (Operation *op = v.getDefiningOp())
                        return isCastOfBlockArgument(op);
diff --git a/mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp
@@ -43,16 +43,16 @@
 namespace mlir {
 namespace linalg {
 Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim) {
-  if (val.getType().isa<UnrankedMemRefType, MemRefType>())
+  if (isa<UnrankedMemRefType, MemRefType>(val.getType()))
     return b.createOrFold<memref::DimOp>(loc, val, dim);
-  if (val.getType().isa<UnrankedTensorType, RankedTensorType>())
+  if (isa<UnrankedTensorType, RankedTensorType>(val.getType()))
    return b.createOrFold<tensor::DimOp>(loc, val, dim);
   llvm_unreachable("Expected MemRefType or TensorType");
 }
 
 OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val,
                                int64_t dim) {
-  auto shapedType = val.getType().cast<ShapedType>();
+  auto shapedType = cast<ShapedType>(val.getType());
   if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
     return createOrFoldDimOp(b, loc, val, dim);
   return b.getIndexAttr(shapedType.getDimSize(dim));
@@ -60,7 +60,7 @@
 
 SmallVector<Value> createDynamicDimensions(OpBuilder &b, Location loc,
                                            Value val) {
-  auto shapedType = val.getType().cast<ShapedType>();
+  auto shapedType = cast<ShapedType>(val.getType());
   assert(shapedType.hasRank() && "`val` must have a static rank");
   SmallVector<Value> res;
   res.reserve(shapedType.getRank());
@@ -73,7 +73,7 @@
 
 SmallVector<OpFoldResult> getMixedDimensions(OpBuilder &b, Location loc,
                                              Value val) {
-  auto shapedType = val.getType().cast<ShapedType>();
+  auto shapedType = cast<ShapedType>(val.getType());
   assert(shapedType.hasRank() && "`val` must have a static rank");
   SmallVector<Value> dynamicDims = createDynamicDimensions(b, loc, val);
   return getMixedValues(shapedType.getShape(), dynamicDims, b);
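// A hedged sketch (not part of this patch) of the idiom the hunks above rely
// on: every mlir::Value is either a BlockArgument or an OpResult, so one free
// dyn_cast plus a fallback cast is exhaustive. `positionOf` is a hypothetical
// helper name; compiling this requires MLIR headers and libraries.
#include "mlir/IR/Value.h"

static unsigned positionOf(mlir::Value v) {
  // Free-function form, post-migration.
  if (auto arg = mlir::dyn_cast<mlir::BlockArgument>(v))
    return arg.getArgNumber(); // index in the owning block's argument list
  return mlir::cast<mlir::OpResult>(v).getResultNumber(); // index among results
}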
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -281,7 +281,7 @@
     auto linalgOp = current.getDefiningOp<LinalgOp>();
     if (!linalgOp)
       break;
-    OpResult opResult = current.cast<OpResult>();
+    OpResult opResult = cast<OpResult>(current);
     current = linalgOp.getDpsInitOperand(opResult.getResultNumber())->get();
   }
   auto padOp = current ? current.getDefiningOp<tensor::PadOp>() : nullptr;
@@ -331,7 +331,7 @@
 GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
                           Value outputTensor,
                           ArrayRef<int64_t> transposeVector) {
-  auto resultTensorType = outputTensor.getType().cast<RankedTensorType>();
+  auto resultTensorType = cast<RankedTensorType>(outputTensor.getType());
   Type elementType = resultTensorType.getElementType();
 
   assert(isPermutationVector(transposeVector) &&
@@ -366,9 +366,9 @@
 }
 
 GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
-  auto memrefTypeTo = to.getType().cast<MemRefType>();
+  auto memrefTypeTo = cast<MemRefType>(to.getType());
 #ifndef NDEBUG
-  auto memrefTypeFrom = from.getType().cast<MemRefType>();
+  auto memrefTypeFrom = cast<MemRefType>(from.getType());
   assert(memrefTypeFrom.getRank() == memrefTypeTo.getRank() &&
          "`from` and `to` memref must have the same rank");
 #endif // NDEBUG
@@ -650,7 +650,7 @@
 static Value materializeTiledShape(OpBuilder &builder, Location loc,
                                    Value valueToTile,
                                    const SliceParameters &sliceParams) {
-  auto shapedType = valueToTile.getType().dyn_cast<ShapedType>();
+  auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
   auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
                       .Case([&](MemRefType) {
                         return builder.create<memref::SubViewOp>(
@@ -685,7 +685,7 @@
                ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
                ArrayRef<OpFoldResult> subShapeSizes,
                bool omitPartialTileCheck) {
-  auto shapedType = valueToTile.getType().dyn_cast<ShapedType>();
+  auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
   assert(shapedType && "only shaped types can be tiled");
   ArrayRef<int64_t> shape = shapedType.getShape();
   int64_t rank = shapedType.getRank();
@@ -889,7 +889,7 @@
     // subdomains explicit.
 
     Type operandType = opOperand.get().getType();
-    if (!isTiled(map, tileSizes) && !(operandType.isa<RankedTensorType>() &&
+    if (!isTiled(map, tileSizes) && !(isa<RankedTensorType>(operandType) &&
                                       linalgOp.isDpsInit(&opOperand))) {
       allSliceParams.push_back(std::nullopt);
       LLVM_DEBUG(llvm::dbgs()
@@ -971,7 +971,7 @@
     auto size = it.value();
     curr.push_back(dim);
     auto attr = size.dyn_cast<Attribute>();
-    if (attr && attr.cast<IntegerAttr>().getInt() == 1)
+    if (attr && cast<IntegerAttr>(attr).getInt() == 1)
      continue;
     reassociation.emplace_back(ReassociationIndices{});
     std::swap(reassociation.back(), curr);
@@ -989,7 +989,7 @@
   // Builder only used as helper for attribute creation.
   OpBuilder b(op->getContext());
   Type resultType = op->getResult(0).getType();
-  if (auto floatType = resultType.dyn_cast<FloatType>()) {
+  if (auto floatType = dyn_cast<FloatType>(resultType)) {
     const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
     if (isa<arith::AddFOp>(op))
       return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -64,7 +64,7 @@
 
   // Maybe broadcasts scalar value into vector type compatible with `op`.
   auto bcast = [&](Value value) -> Value {
-    if (auto vec = op.getType().dyn_cast<VectorType>())
+    if (auto vec = dyn_cast<VectorType>(op.getType()))
       return rewriter.create<vector::BroadcastOp>(op.getLoc(), vec, value);
     return value;
   };
@@ -167,7 +167,7 @@
 
   // Maybe broadcasts scalar value into vector type compatible with `op`.
   auto bcast = [&loc, &op, &rewriter](Value value) -> Value {
-    if (auto vec = op.getType().template dyn_cast<VectorType>())
+    if (auto vec = dyn_cast<VectorType>(op.getType()))
       return rewriter.create<vector::BroadcastOp>(loc, vec, value);
     return value;
   };
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -40,7 +40,7 @@
 // Returns vector shape if the type is a vector. Returns an empty shape if it is
 // not a vector.
 static ArrayRef<int64_t> vectorShape(Type type) {
-  auto vectorType = type.dyn_cast<VectorType>();
+  auto vectorType = dyn_cast<VectorType>(type);
   return vectorType ? vectorType.getShape() : ArrayRef<int64_t>();
 }
@@ -54,14 +54,14 @@
 
 // Broadcasts scalar type into vector type (iff shape is non-scalar).
 static Type broadcast(Type type, ArrayRef<int64_t> shape) {
-  assert(!type.isa<VectorType>() && "must be scalar type");
+  assert(!isa<VectorType>(type) && "must be scalar type");
   return !shape.empty() ? VectorType::get(shape, type) : type;
 }
 
 // Broadcasts scalar value into vector (iff shape is non-scalar).
 static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
                        ArrayRef<int64_t> shape) {
-  assert(!value.getType().isa<VectorType>() && "must be scalar value");
+  assert(!isa<VectorType>(value.getType()) && "must be scalar value");
   auto type = broadcast(value.getType(), shape);
   return !shape.empty() ? builder.create<vector::BroadcastOp>(type, value) : value;
 }
@@ -92,7 +92,7 @@
   assert(!operands.empty() && "operands must be not empty");
   assert(vectorWidth > 0 && "vector width must be larger than 0");
 
-  VectorType inputType = operands[0].getType().cast<VectorType>();
+  VectorType inputType = cast<VectorType>(operands[0].getType());
   ArrayRef<int64_t> inputShape = inputType.getShape();
 
   // If input shape matches target vector width, we can just call the
@@ -118,7 +118,7 @@
 
   for (unsigned i = 0; i < operands.size(); ++i) {
     auto operand = operands[i];
-    auto eltType = operand.getType().cast<VectorType>().getElementType();
+    auto eltType = cast<VectorType>(operand.getType()).getElementType();
     auto expandedType = VectorType::get(expandedShape, eltType);
     expandedOperands[i] =
         builder.create<vector::ShapeCastOp>(expandedType, operand);
@@ -145,7 +145,7 @@
   }
 
   // Stitch results together into one large vector.
-  Type resultEltType = results[0].getType().cast<VectorType>().getElementType();
+  Type resultEltType = cast<VectorType>(results[0].getType()).getElementType();
   Type resultExpandedType = VectorType::get(expandedShape, resultEltType);
   Value result = builder.create<arith::ConstantOp>(
       resultExpandedType, builder.getZeroAttr(resultExpandedType));
@@ -318,9 +318,9 @@
   // Create F32 equivalent type.
   Type newType;
-  if (auto shaped = origType.dyn_cast<ShapedType>()) {
+  if (auto shaped = dyn_cast<ShapedType>(origType)) {
     newType = shaped.clone(rewriter.getF32Type());
-  } else if (origType.isa<FloatType>()) {
+  } else if (isa<FloatType>(origType)) {
     newType = rewriter.getF32Type();
   } else {
     return rewriter.notifyMatchFailure(op,
diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
--- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -69,7 +69,7 @@
     results.push_back(*newBuffer);
   }
 
-  transformResults.set(getResult().cast<OpResult>(), results);
+  transformResults.set(cast<OpResult>(getResult()), results);
   return DiagnosedSilenceableFailure::success();
 }
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
@@ -57,7 +57,7 @@
     // always 1.
     if (llvm::all_of(strides, [](OpFoldResult &valueOrAttr) {
           Attribute attr = valueOrAttr.dyn_cast<Attribute>();
-          return attr && attr.cast<IntegerAttr>().getInt() == 1;
+          return attr && cast<IntegerAttr>(attr).getInt() == 1;
         })) {
       strides = SmallVector<OpFoldResult>(sourceOp.getMixedStrides().size(),
                                           rewriter.getI64IntegerAttr(1));
@@ -93,8 +93,8 @@
         // If both offsets are static we can simply calculate the combined
         // offset statically.
         offsets.push_back(rewriter.getI64IntegerAttr(
-            opOffsetAttr.cast<IntegerAttr>().getInt() +
-            sourceOffsetAttr.cast<IntegerAttr>().getInt()));
+            cast<IntegerAttr>(opOffsetAttr).getInt() +
+            cast<IntegerAttr>(sourceOffsetAttr).getInt()));
       } else {
         // When either offset is dynamic, we must emit an additional affine
         // transformation to add the two offsets together dynamically.
@@ -102,7 +102,7 @@
         SmallVector<Value> affineApplyOperands;
         for (auto valueOrAttr : {opOffset, sourceOffset}) {
           if (auto attr = valueOrAttr.dyn_cast<Attribute>()) {
-            expr = expr + attr.cast<IntegerAttr>().getInt();
+            expr = expr + cast<IntegerAttr>(attr).getInt();
           } else {
             expr =
                 expr + rewriter.getAffineSymbolExpr(affineApplyOperands.size());
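// Worked example of the static-offset folding in the ComposeSubView hunks
// above: with unit strides, the offsets of nested subviews compose by
// elementwise addition. Plain C++, no MLIR dependency; the values are made up.
#include <cstdio>
#include <vector>

int main() {
  std::vector<long> outerOffsets = {2, 3}; // subview(source), offsets [2, 3]
  std::vector<long> innerOffsets = {4, 5}; // subview(subview), offsets [4, 5]
  std::vector<long> composed(outerOffsets.size());
  for (size_t i = 0; i < outerOffsets.size(); ++i)
    composed[i] = outerOffsets[i] + innerOffsets[i]; // getInt() + getInt()
  std::printf("composed offsets: [%ld, %ld]\n", composed[0], composed[1]);
  return 0;
}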
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
@@ -149,7 +149,7 @@
                            arith::WideIntEmulationConverter &typeConverter) {
   typeConverter.addConversion(
       [&typeConverter](MemRefType ty) -> std::optional<Type> {
-        auto intTy = ty.getElementType().dyn_cast<IntegerType>();
+        auto intTy = dyn_cast<IntegerType>(ty.getElementType());
         if (!intTy)
           return ty;
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
@@ -89,11 +89,11 @@
   LogicalResult matchAndRewrite(memref::ReshapeOp op,
                                 PatternRewriter &rewriter) const final {
-    auto shapeType = op.getShape().getType().cast<MemRefType>();
+    auto shapeType = cast<MemRefType>(op.getShape().getType());
     if (!shapeType.hasStaticShape())
       return failure();
 
-    int64_t rank = shapeType.cast<MemRefType>().getDimSize(0);
+    int64_t rank = cast<MemRefType>(shapeType).getDimSize(0);
     SmallVector<OpFoldResult> sizes, strides;
     sizes.resize(rank);
     strides.resize(rank);
@@ -106,7 +106,7 @@
       if (op.getType().isDynamicDim(i)) {
         Value index = rewriter.create<arith::ConstantIndexOp>(loc, i);
         size = rewriter.create<memref::LoadOp>(loc, op.getShape(), index);
-        if (!size.getType().isa<IndexType>())
+        if (!isa<IndexType>(size.getType()))
           size = rewriter.create<arith::IndexCastOp>(
               loc, rewriter.getIndexType(), size);
         sizes[i] = size;
@@ -141,7 +141,7 @@
               op.getKind() != arith::AtomicRMWKind::minf;
       });
   target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
-    return !op.getShape().getType().cast<MemRefType>().hasStaticShape();
+    return !cast<MemRefType>(op.getShape().getType()).hasStaticShape();
   });
   if (failed(applyPartialConversion(getOperation(), target,
                                     std::move(patterns))))
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -62,7 +62,7 @@
   // Build a plain extract_strided_metadata(memref) from subview(memref).
   Location origLoc = subview.getLoc();
   Value source = subview.getSource();
-  auto sourceType = source.getType().cast<MemRefType>();
+  auto sourceType = cast<MemRefType>(source.getType());
   unsigned sourceRank = sourceType.getRank();
 
   auto newExtractStridedMetadata =
@@ -115,7 +115,7 @@
   // The final result is <base, offset, sizes, strides>.
   // Thus we need 1 + 1 + subview.getRank() + subview.getRank(), to hold all
   // the values.
-  auto subType = subview.getType().cast<MemRefType>();
+  auto subType = cast<MemRefType>(subview.getType());
   unsigned subRank = subType.getRank();
 
   // The sizes of the final type are defined directly by the input sizes of
@@ -338,7 +338,7 @@
 
   // Collect the statically known information about the original stride.
   Value source = expandShape.getSrc();
-  auto sourceType = source.getType().cast<MemRefType>();
+  auto sourceType = cast<MemRefType>(source.getType());
   auto [strides, offset] = getStridesAndOffset(sourceType);
 
   OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
@@ -358,10 +358,9 @@
   AffineExpr s0 = builder.getAffineSymbolExpr(0);
   AffineExpr s1 = builder.getAffineSymbolExpr(1);
   for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
-    int64_t baseExpandedStride = expandedStrides[doneStrideIdx]
-                                     .get<Attribute>()
-                                     .cast<IntegerAttr>()
-                                     .getInt();
+    int64_t baseExpandedStride =
+        cast<IntegerAttr>(expandedStrides[doneStrideIdx].get<Attribute>())
+            .getInt();
     expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
         builder, expandShape.getLoc(),
         (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
@@ -372,10 +371,9 @@
   // Now apply the origStride to the remaining dimensions.
   AffineExpr s0 = builder.getAffineSymbolExpr(0);
   for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
-    int64_t baseExpandedStride = expandedStrides[doneStrideIdx]
-                                     .get<Attribute>()
-                                     .cast<IntegerAttr>()
-                                     .getInt();
+    int64_t baseExpandedStride =
+        cast<IntegerAttr>(expandedStrides[doneStrideIdx].get<Attribute>())
+            .getInt();
     expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
         builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
   }
@@ -445,7 +443,7 @@
   // Build the affine expr of the product of the original sizes involved in that
   // group.
   Value source = collapseShape.getSrc();
-  auto sourceType = source.getType().cast<MemRefType>();
+  auto sourceType = cast<MemRefType>(source.getType());
 
   SmallVector<int64_t> reassocGroup =
       collapseShape.getReassociationIndices()[groupId];
@@ -479,7 +477,7 @@
          "Reassociation group should have at least one dimension");
 
   Value source = collapseShape.getSrc();
-  auto sourceType = source.getType().cast<MemRefType>();
+  auto sourceType = cast<MemRefType>(source.getType());
 
   auto [strides, offset] = getStridesAndOffset(sourceType);
 
@@ -562,7 +560,7 @@
   // extract_strided_metadata(reassociative_reshape_like(memref)).
   Location origLoc = reshape.getLoc();
   Value source = reshape.getSrc();
-  auto sourceType = source.getType().cast<MemRefType>();
+  auto sourceType = cast<MemRefType>(source.getType());
   unsigned sourceRank = sourceType.getRank();
 
   auto newExtractStridedMetadata =
@@ -650,8 +648,7 @@
   if (!allocLikeOp)
     return failure();
 
-  auto memRefType =
-      allocLikeOp.getResult().getType().template cast<MemRefType>();
+  auto memRefType = cast<MemRefType>(allocLikeOp.getResult().getType());
   if (!memRefType.getLayout().isIdentity())
     return rewriter.notifyMatchFailure(
         allocLikeOp, "alloc-like operations should have been normalized");
@@ -688,7 +685,7 @@
   SmallVector<Value> results;
   results.reserve(rank * 2 + 2);
 
-  auto baseBufferType = op.getBaseBuffer().getType().cast<MemRefType>();
+  auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
   int64_t offset = 0;
   if (allocLikeOp.getType() == baseBufferType)
     results.push_back(allocLikeOp);
@@ -737,7 +734,7 @@
   if (!getGlobalOp)
     return failure();
 
-  auto memRefType = getGlobalOp.getResult().getType().cast<MemRefType>();
+  auto memRefType = cast<MemRefType>(getGlobalOp.getResult().getType());
   if (!memRefType.getLayout().isIdentity()) {
     return rewriter.notifyMatchFailure(
         getGlobalOp,
@@ -759,7 +756,7 @@
   SmallVector<Value> results;
   results.reserve(rank * 2 + 2);
 
-  auto baseBufferType = op.getBaseBuffer().getType().cast<MemRefType>();
+  auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
   int64_t offset = 0;
   if (getGlobalOp.getType() == baseBufferType)
     results.push_back(getGlobalOp);
@@ -838,8 +835,7 @@
     return rewriter.notifyMatchFailure(
         reinterpretCastOp, "reinterpret_cast source's type is incompatible");
 
-  auto memrefType =
-      reinterpretCastOp.getResult().getType().cast<MemRefType>();
+  auto memrefType = cast<MemRefType>(reinterpretCastOp.getResult().getType());
   unsigned rank = memrefType.getRank();
   SmallVector<OpFoldResult> results;
   results.resize_for_overwrite(rank * 2 + 2);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
@@ -120,7 +120,7 @@
 static FailureOr<Value>
 getTransferLikeOpSrcMemRef(TransferLikeOp transferLikeOp) {
   Value src = transferLikeOp.getSource();
-  if (src.getType().isa<MemRefType>())
+  if (isa<MemRefType>(src.getType()))
     return src;
   return failure();
 }
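// Minimal sketch of the FailureOr idiom used by getTransferLikeOpSrcMemRef
// just above (assumes MLIR headers; the helper name is hypothetical):
// failure() when the source is not a memref, the value itself otherwise.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h"

static mlir::FailureOr<mlir::Value> memrefSourceOf(mlir::Value src) {
  if (mlir::isa<mlir::MemRefType>(src.getType()))
    return src;
  return mlir::failure();
}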
@@ -240,7 +240,7 @@
     return rewriter.notifyMatchFailure(loadStoreLikeOp,
                                        "source is not a memref");
   Value srcMemRef = *failureOrSrcMemRef;
-  auto ldStTy = srcMemRef.getType().cast<MemRefType>();
+  auto ldStTy = cast<MemRefType>(srcMemRef.getType());
   unsigned loadStoreRank = ldStTy.getRank();
   // Don't waste compile time if there is nothing to rewrite.
   if (loadStoreRank == 0)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -148,7 +148,7 @@
   if (collapseShapeOp.getReassociationIndices().empty()) {
     auto zeroAffineMap = rewriter.getConstantAffineMap(0);
     int64_t srcRank =
-        collapseShapeOp.getViewSource().getType().cast<MemRefType>().getRank();
+        cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
     for (int64_t i = 0; i < srcRank; i++) {
       OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
           rewriter, loc, zeroAffineMap, dynamicIndices);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
@@ -71,11 +71,9 @@
                                   UnrealizedConversionCastOp conversionOp,
                                   SubViewOp op) {
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(op);
-  auto newResultType =
-      SubViewOp::inferRankReducedResultType(
-          op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(),
-          op.getMixedSizes(), op.getMixedStrides())
-          .cast<MemRefType>();
+  auto newResultType = cast<MemRefType>(SubViewOp::inferRankReducedResultType(
+      op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(),
+      op.getMixedSizes(), op.getMixedStrides()));
   Value newSubview = rewriter.create<SubViewOp>(
       op.getLoc(), newResultType, conversionOp.getOperand(0),
       op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides());
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
@@ -61,11 +61,11 @@
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPoint(subviewUse);
     Type newType = memref::SubViewOp::inferRankReducedResultType(
-        subviewUse.getType().getShape(), val.getType().cast<MemRefType>(),
+        subviewUse.getType().getShape(), cast<MemRefType>(val.getType()),
         subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
         subviewUse.getStaticStrides());
     Value newSubview = rewriter.create<memref::SubViewOp>(
-        subviewUse->getLoc(), newType.cast<MemRefType>(), val,
+        subviewUse->getLoc(), cast<MemRefType>(newType), val,
         subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
         subviewUse.getMixedStrides());
 
@@ -209,9 +209,9 @@
   for (int64_t i = 0, e = originalShape.size(); i != e; ++i)
     sizes[1 + i] = rewriter.getIndexAttr(originalShape[i]);
   // Strides is [1, 1 ... 1 ].
-  auto dstMemref = memref::SubViewOp::inferRankReducedResultType(
-                       originalShape, mbMemRefType, offsets, sizes, strides)
-                       .cast<MemRefType>();
+  auto dstMemref =
+      cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
+          originalShape, mbMemRefType, offsets, sizes, strides));
   Value subview = rewriter.create<memref::SubViewOp>(loc, dstMemref, mbAlloc,
                                                      offsets, sizes, strides);
   LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n");
diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
@@ -180,7 +180,7 @@
          llvm::seq<unsigned>(0, callOp.getNumResults())) {
       Value oldMemRef = callOp.getResult(resIndex);
       if (auto oldMemRefType =
-              oldMemRef.getType().dyn_cast<MemRefType>())
+              dyn_cast<MemRefType>(oldMemRef.getType()))
         if (!oldMemRefType.getLayout().isIdentity() &&
             !isMemRefNormalizable(oldMemRef.getUsers()))
           return WalkResult::interrupt();
@@ -192,7 +192,7 @@
   for (unsigned argIndex :
        llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
     BlockArgument oldMemRef = funcOp.getArgument(argIndex);
-    if (auto oldMemRefType = oldMemRef.getType().dyn_cast<MemRefType>())
+    if (auto oldMemRefType = dyn_cast<MemRefType>(oldMemRef.getType()))
       if (!oldMemRefType.getLayout().isIdentity() &&
           !isMemRefNormalizable(oldMemRef.getUsers()))
         return false;
@@ -226,7 +226,7 @@
   funcOp.walk([&](func::ReturnOp returnOp) {
     for (const auto &operandEn : llvm::enumerate(returnOp.getOperands())) {
       Type opType = operandEn.value().getType();
-      MemRefType memrefType = opType.dyn_cast<MemRefType>();
+      MemRefType memrefType = dyn_cast<MemRefType>(opType);
       // If type is not memref or if the memref type is same as that in
       // function's return signature then no update is required.
       if (!memrefType || memrefType == resultTypes[operandEn.index()])
@@ -284,7 +284,7 @@
     if (oldResult.getType() == newResult.getType())
       continue;
     AffineMap layoutMap =
-        oldResult.getType().cast<MemRefType>().getLayout().getAffineMap();
+        cast<MemRefType>(oldResult.getType()).getLayout().getAffineMap();
     if (failed(replaceAllMemRefUsesWith(oldResult, /*newMemRef=*/newResult,
                                         /*extraIndices=*/{},
                                         /*indexRemap=*/layoutMap,
@@ -358,7 +358,7 @@
   for (unsigned argIndex :
        llvm::seq<unsigned>(0, functionType.getNumInputs())) {
     Type argType = functionType.getInput(argIndex);
-    MemRefType memrefType = argType.dyn_cast<MemRefType>();
+    MemRefType memrefType = dyn_cast<MemRefType>(argType);
     // Check whether argument is of MemRef type. Any other argument type can
     // simply be part of the final function signature.
     if (!memrefType) {
@@ -422,11 +422,11 @@
     // Replace all uses of the old memrefs.
     Value oldMemRef = op->getResult(resIndex);
     Value newMemRef = newOp->getResult(resIndex);
-    MemRefType oldMemRefType = oldMemRef.getType().dyn_cast<MemRefType>();
+    MemRefType oldMemRefType = dyn_cast<MemRefType>(oldMemRef.getType());
     // Check whether the operation result is MemRef type.
     if (!oldMemRefType)
       continue;
-    MemRefType newMemRefType = newMemRef.getType().cast<MemRefType>();
+    MemRefType newMemRefType = cast<MemRefType>(newMemRef.getType());
     if (oldMemRefType == newMemRefType)
       continue;
     // TODO: Assume single layout map. Multiple maps not supported.
@@ -466,7 +466,7 @@
   for (unsigned resIndex :
        llvm::seq<unsigned>(0, functionType.getNumResults())) {
     Type resType = functionType.getResult(resIndex);
-    MemRefType memrefType = resType.dyn_cast<MemRefType>();
+    MemRefType memrefType = dyn_cast<MemRefType>(resType);
     // Check whether result is of MemRef type. Any other argument type can
     // simply be part of the final function signature.
     if (!memrefType) {
@@ -507,7 +507,7 @@
   bool resultTypeNormalized = false;
   for (unsigned resIndex : llvm::seq<unsigned>(0, oldOp->getNumResults())) {
     auto resultType = oldOp->getResult(resIndex).getType();
-    MemRefType memrefType = resultType.dyn_cast<MemRefType>();
+    MemRefType memrefType = dyn_cast<MemRefType>(resultType);
     // Check whether the operation result is MemRef type.
     if (!memrefType) {
       resultTypes.push_back(resultType);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -40,7 +40,7 @@
   LogicalResult matchAndRewrite(OpTy dimOp,
                                 PatternRewriter &rewriter) const override {
-    OpResult dimValue = dimOp.getSource().template dyn_cast<OpResult>();
+    OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
     if (!dimValue)
       return failure();
     auto shapedTypeOp =
@@ -61,8 +61,8 @@
       return failure();
 
     Value resultShape = reifiedResultShapes[dimValue.getResultNumber()];
-    auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
-    if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
+    auto resultShapeType = dyn_cast<RankedTensorType>(resultShape.getType());
+    if (!resultShapeType || !isa<IndexType>(resultShapeType.getElementType()))
       return failure();
 
     Location loc = dimOp->getLoc();
@@ -82,7 +82,7 @@
   LogicalResult matchAndRewrite(OpTy dimOp,
                                 PatternRewriter &rewriter) const override {
-    OpResult dimValue = dimOp.getSource().template dyn_cast<OpResult>();
+    OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
     if (!dimValue)
       return failure();
     std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -38,14 +38,14 @@
   void generateRuntimeVerification(Operation *op, OpBuilder &builder,
                                    Location loc) const {
     auto castOp = cast<CastOp>(op);
-    auto srcType = castOp.getSource().getType().cast<BaseMemRefType>();
+    auto srcType = cast<BaseMemRefType>(castOp.getSource().getType());
 
     // Nothing to check if the result is an unranked memref.
-    auto resultType = castOp.getType().dyn_cast<MemRefType>();
+    auto resultType = dyn_cast<MemRefType>(castOp.getType());
     if (!resultType)
       return;
 
-    if (srcType.isa<UnrankedMemRefType>()) {
+    if (isa<UnrankedMemRefType>(srcType)) {
       // Check rank.
       Value srcRank = builder.create<RankOp>(loc, castOp.getSource());
       Value resultRank =
@@ -75,7 +75,7 @@
     // Check dimension sizes.
     for (const auto &it : llvm::enumerate(resultType.getShape())) {
       // Static dim size -> static/dynamic dim size does not need verification.
-      if (auto rankedSrcType = srcType.dyn_cast<MemRefType>())
+      if (auto rankedSrcType = dyn_cast<MemRefType>(srcType))
         if (!rankedSrcType.isDynamicDim(it.index()))
           continue;
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
--- a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
@@ -42,7 +42,7 @@
     Location location = op->getLoc();
 
     if (op->hasAttr(op.getTf32EnabledAttrName()) ||
-        !op.getMatrixA().getType().cast<VectorType>().getElementType().isF32())
+        !cast<VectorType>(op.getMatrixA().getType()).getElementType().isF32())
       return failure();
 
     if (precision == MmaSyncF32Lowering::Unkown)
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
--- a/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
@@ -180,7 +180,7 @@
 mlir::LogicalResult
 mlir::nvgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
                                                 Value memrefValue) {
-  auto memRefType = memrefValue.getType().dyn_cast<MemRefType>();
+  auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
   if (!memRefType ||
       !NVGPUDialect::hasSharedMemoryAddressSpace(memRefType))
     return failure();
diff --git a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
--- a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
+++ b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
@@ -63,7 +63,7 @@
     info.vectorType = writeOp.getVectorType();
   } else if (isa<vector::TransferReadOp, vector::ContractionOp,
                  vector::ExtractStridedSliceOp, arith::ConstantOp>(op)) {
-    info.vectorType = op->getResult(0).getType().cast<VectorType>();
+    info.vectorType = cast<VectorType>(op->getResult(0).getType());
   } else {
     return op->emitError()
            << "unhandled operation type in nvgpu.mma.sync conversion path";
diff --git a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
--- a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
+++ b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
@@ -14,13 +14,13 @@
 using namespace mlir::quant;
 
 static bool isQuantizablePrimitiveType(Type inputType) {
-  return inputType.isa<FloatType>();
+  return isa<FloatType>(inputType);
 }
 
 ExpressedToQuantizedConverter
 ExpressedToQuantizedConverter::forInputType(Type inputType) {
-  if (inputType.isa<TensorType, VectorType>()) {
-    Type elementType = inputType.cast<ShapedType>().getElementType();
+  if (isa<TensorType, VectorType>(inputType)) {
+    Type elementType = cast<ShapedType>(inputType).getElementType();
     if (!isQuantizablePrimitiveType(elementType))
       return ExpressedToQuantizedConverter{inputType, nullptr};
     return ExpressedToQuantizedConverter{inputType, elementType};
@@ -34,11 +34,11 @@
 
 Type ExpressedToQuantizedConverter::convert(QuantizedType elementalType) const {
   assert(expressedType && "convert() on unsupported conversion");
-  if (auto tensorType = inputType.dyn_cast<RankedTensorType>())
+  if (auto tensorType = dyn_cast<RankedTensorType>(inputType))
     return RankedTensorType::get(tensorType.getShape(), elementalType);
-  if (auto tensorType = inputType.dyn_cast<UnrankedTensorType>())
+  if (auto tensorType = dyn_cast<UnrankedTensorType>(inputType))
     return UnrankedTensorType::get(elementalType);
-  if (auto vectorType = inputType.dyn_cast<VectorType>())
+  if (auto vectorType = dyn_cast<VectorType>(inputType))
     return VectorType::get(vectorType.getShape(), elementalType);
 
   // If the expressed types match, just use the new elemental type.
@@ -50,7 +50,7 @@
 
 ElementsAttr
 UniformQuantizedPerAxisValueConverter::convert(Attribute realValue) {
-  if (auto attr = realValue.dyn_cast<DenseFPElementsAttr>()) {
+  if (auto attr = dyn_cast<DenseFPElementsAttr>(realValue)) {
     return convert(attr);
   }
   // TODO: handles sparse elements attribute
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -49,7 +49,7 @@
     }
     parents.insert(loop);
   }
-  results.set(getResult().cast<OpResult>(), parents.getArrayRef());
+  results.set(cast<OpResult>(getResult()), parents.getArrayRef());
   return DiagnosedSilenceableFailure::success();
 }
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -30,8 +30,8 @@
 /// Helper function for loop bufferization. Cast the given buffer to the given
 /// memref type.
 static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
-  assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
-  assert(buffer.getType().isa<BaseMemRefType>() && "expected BaseMemRefType");
+  assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType");
+  assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType");
 
   // If the buffer already has the correct type, no cast is needed.
   if (buffer.getType() == type)
     return buffer;
@@ -78,7 +78,7 @@
   SmallVector<Value> newArgs;
   for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
     Value value = it.value();
-    if (value.getType().isa<TensorType>()) {
+    if (isa<TensorType>(value.getType())) {
       FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
       if (failed(maybeBuffer))
         return failure();
@@ -141,7 +141,7 @@
   rewriter.setInsertionPointAfter(newOp);
   SmallVector<Value> newResults;
   for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
-    if (it.value().isa<TensorType>()) {
+    if (isa<TensorType>(it.value())) {
       newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
           executeRegionOp.getLoc(), newOp->getResult(it.index())));
     } else {
@@ -183,7 +183,7 @@
   // Compute bufferized result types.
   SmallVector<Type> newTypes;
   for (Value result : ifOp.getResults()) {
-    if (!result.getType().isa<TensorType>()) {
+    if (!isa<TensorType>(result.getType())) {
       newTypes.push_back(result.getType());
       continue;
     }
@@ -218,13 +218,13 @@
     assert(value.getDefiningOp() == op && "invalid valid");
 
     // Determine buffer types of the true/false branches.
-    auto opResult = value.cast<OpResult>();
+    auto opResult = cast<OpResult>(value);
     auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber());
     auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber());
     BaseMemRefType thenBufferType, elseBufferType;
-    if (thenValue.getType().isa<BaseMemRefType>()) {
+    if (isa<BaseMemRefType>(thenValue.getType())) {
       // True branch was already bufferized.
-      thenBufferType = thenValue.getType().cast<BaseMemRefType>();
+      thenBufferType = cast<BaseMemRefType>(thenValue.getType());
     } else {
       auto maybeBufferType =
           bufferization::getBufferType(thenValue, options, fixedTypes);
@@ -232,9 +232,9 @@
       if (failed(maybeBufferType))
         return failure();
       thenBufferType = *maybeBufferType;
     }
-    if (elseValue.getType().isa<BaseMemRefType>()) {
+    if (isa<BaseMemRefType>(elseValue.getType())) {
       // False branch was already bufferized.
-      elseBufferType = elseValue.getType().cast<BaseMemRefType>();
+      elseBufferType = cast<BaseMemRefType>(elseValue.getType());
     } else {
       auto maybeBufferType =
           bufferization::getBufferType(elseValue, options, fixedTypes);
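// Sketch of the gating idiom recurring in the SCF bufferization hunks below
// (assumes MLIR headers; the helper name is hypothetical): only tensor-typed
// values participate in bufferization, all others pass through unchanged.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

static bool needsBufferization(mlir::Value v) {
  return mlir::isa<mlir::TensorType>(v.getType()); // free-function isa
}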
@@ -253,7 +253,7 @@
     // Layout maps are different: Promote to fully dynamic layout map.
     return getMemRefTypeWithFullyDynamicLayout(
-        opResult.getType().cast<TensorType>(), thenBufferType.getMemorySpace());
+        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
   }
 };
 
@@ -262,7 +262,7 @@
 static DenseSet<int64_t> getTensorIndices(ValueRange values) {
   DenseSet<int64_t> result;
   for (const auto &it : llvm::enumerate(values))
-    if (it.value().getType().isa<TensorType>())
+    if (isa<TensorType>(it.value().getType()))
       result.insert(it.index());
   return result;
 }
@@ -275,8 +275,8 @@
   unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
   DenseSet<int64_t> result;
   for (unsigned int i = 0; i < minSize; ++i) {
-    if (!bbArgs[i].getType().isa<TensorType>() ||
-        !yieldedValues[i].getType().isa<TensorType>())
+    if (!isa<TensorType>(bbArgs[i].getType()) ||
+        !isa<TensorType>(yieldedValues[i].getType()))
      continue;
     if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
       result.insert(i);
@@ -291,7 +291,7 @@
                            const BufferizationOptions &options) {
   SmallVector<Value> result;
   for (OpOperand &opOperand : operands) {
-    if (opOperand.get().getType().isa<TensorType>()) {
+    if (isa<TensorType>(opOperand.get().getType())) {
       FailureOr<Value> resultBuffer =
           getBuffer(rewriter, opOperand.get(), options);
       if (failed(resultBuffer))
@@ -361,9 +361,9 @@
 
   // Compute the buffer type of the yielded value.
   BaseMemRefType yieldedValueBufferType;
-  if (yieldedValue.getType().isa<BaseMemRefType>()) {
+  if (isa<BaseMemRefType>(yieldedValue.getType())) {
     // scf.yield was already bufferized.
-    yieldedValueBufferType = yieldedValue.getType().cast<BaseMemRefType>();
+    yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
   } else {
     auto maybeBufferType =
         bufferization::getBufferType(yieldedValue, options, newFixedTypes);
@@ -379,7 +379,7 @@
   // If there is a mismatch between the yielded buffer type and the iter_arg
   // buffer type, the buffer type must be promoted to a fully dynamic layout
   // map.
-  auto yieldedRanked = yieldedValueBufferType.cast<MemRefType>();
+  auto yieldedRanked = cast<MemRefType>(yieldedValueBufferType);
 #ifndef NDEBUG
   auto iterRanked = initArgBufferType->cast<MemRefType>();
   assert(llvm::equal(yieldedRanked.getShape(), iterRanked.getShape()) &&
          "expected same shape");
   assert(yieldedRanked.getMemorySpace() == iterRanked.getMemorySpace() &&
          "expected same memory space");
 #endif // NDEBUG
   return getMemRefTypeWithFullyDynamicLayout(
-      iterArg.getType().cast<TensorType>(),
+      cast<TensorType>(iterArg.getType()),
       yieldedRanked.getMemorySpace());
 }
@@ -516,16 +516,16 @@
     const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
   auto forOp = cast<scf::ForOp>(op);
   assert(getOwnerOfValue(value) == op && "invalid value");
-  assert(value.getType().isa<TensorType>() && "expected tensor type");
+  assert(isa<TensorType>(value.getType()) && "expected tensor type");
 
   // Get result/argument number.
   unsigned resultNum;
-  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
+  if (auto bbArg = dyn_cast<BlockArgument>(value)) {
     resultNum =
         forOp.getResultForOpOperand(forOp.getOpOperandForRegionIterArg(bbArg))
             .getResultNumber();
   } else {
-    resultNum = value.cast<OpResult>().getResultNumber();
+    resultNum = cast<OpResult>(value).getResultNumber();
   }
 
   // Compute the bufferized type.
@@ -560,7 +560,7 @@
     Value initArg = it.value();
     Value result = forOp->getResult(it.index());
     // If the type is not a tensor, bufferization doesn't need to touch it.
-    if (!result.getType().isa<TensorType>()) {
+    if (!isa<TensorType>(result.getType())) {
       castedInitArgs.push_back(initArg);
       continue;
     }
@@ -611,7 +611,7 @@
   auto yieldOp =
       cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
   for (OpResult opResult : op->getOpResults()) {
-    if (!opResult.getType().isa<TensorType>())
+    if (!isa<TensorType>(opResult.getType()))
       continue;
 
     // Note: This is overly strict. We should check for aliasing bufferized
@@ -736,7 +736,7 @@
   for (int64_t idx = 0;
        idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
     Value value = conditionOp.getArgs()[idx];
-    if (!value.getType().isa<TensorType>() ||
+    if (!isa<TensorType>(value.getType()) ||
         (equivalentYieldsAfter.contains(idx) &&
          equivalentYieldsBefore.contains(idx))) {
       beforeYieldValues.push_back(value);
@@ -786,7 +786,7 @@
     Value initArg = it.value();
     Value beforeArg = whileOp.getBeforeArguments()[it.index()];
     // If the type is not a tensor, bufferization doesn't need to touch it.
-    if (!beforeArg.getType().isa<TensorType>()) {
+    if (!isa<TensorType>(beforeArg.getType())) {
       castedInitArgs.push_back(initArg);
       continue;
     }
@@ -799,7 +799,7 @@
   // The result types of a WhileOp are the same as the "after" bbArg types.
   SmallVector<Type> argsTypesAfter = llvm::to_vector(
       llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
-        if (!bbArg.getType().isa<TensorType>())
+        if (!isa<TensorType>(bbArg.getType()))
           return bbArg.getType();
         // TODO: error handling
         return bufferization::getBufferType(bbArg, options)->cast<Type>();
       }));
@@ -848,10 +848,10 @@
     const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
   auto whileOp = cast<scf::WhileOp>(op);
   assert(getOwnerOfValue(value) == op && "invalid value");
-  assert(value.getType().isa<TensorType>() && "expected tensor type");
+  assert(isa<TensorType>(value.getType()) && "expected tensor type");
 
   // Case 1: Block argument of the "before" region.
-  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
+  if (auto bbArg = dyn_cast<BlockArgument>(value)) {
     if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) {
       Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
       auto yieldOp = whileOp.getYieldOp();
@@ -865,18 +865,18 @@
   // The bufferized "after" bbArg type can be directly computed from the
   // bufferized "before" bbArg type.
   unsigned resultNum;
-  if (auto opResult = value.dyn_cast<OpResult>()) {
+  if (auto opResult = dyn_cast<OpResult>(value)) {
     resultNum = opResult.getResultNumber();
-  } else if (value.cast<BlockArgument>().getOwner()->getParent() ==
+  } else if (cast<BlockArgument>(value).getOwner()->getParent() ==
              &whileOp.getAfter()) {
-    resultNum = value.cast<BlockArgument>().getArgNumber();
+    resultNum = cast<BlockArgument>(value).getArgNumber();
   } else {
     llvm_unreachable("invalid value");
   }
   Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
-  if (!conditionYieldedVal.getType().isa<TensorType>()) {
+  if (!isa<TensorType>(conditionYieldedVal.getType())) {
     // scf.condition was already bufferized.
-    return conditionYieldedVal.getType().cast<BaseMemRefType>();
+    return cast<BaseMemRefType>(conditionYieldedVal.getType());
   }
   return bufferization::getBufferType(conditionYieldedVal, options,
                                       fixedTypes);
@@ -902,7 +902,7 @@
   auto conditionOp = whileOp.getConditionOp();
   for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
-    if (!it.value().getType().isa<TensorType>())
+    if (!isa<TensorType>(it.value().getType()))
       continue;
     if (!state.areEquivalentBufferizedValues(
             it.value(), conditionOp->getBlock()->getArgument(it.index())))
@@ -913,7 +913,7 @@
   auto yieldOp = whileOp.getYieldOp();
   for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
-    if (!it.value().getType().isa<TensorType>())
+    if (!isa<TensorType>(it.value().getType()))
       continue;
     if (!state.areEquivalentBufferizedValues(
             it.value(), yieldOp->getBlock()->getArgument(it.index())))
@@ -971,7 +971,7 @@
   SmallVector<Value> newResults;
   for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
     Value value = it.value();
-    if (value.getType().isa<TensorType>()) {
+    if (isa<TensorType>(value.getType())) {
       FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
       if (failed(maybeBuffer))
         return failure();
@@ -1110,7 +1110,7 @@
     const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
   auto forallOp = cast<ForallOp>(op);
 
-  if (auto bbArg = value.dyn_cast<BlockArgument>())
+  if (auto bbArg = dyn_cast<BlockArgument>(value))
     // A tensor block argument has the same bufferized type as the
     // corresponding output operand.
     return bufferization::getBufferType(
@@ -1119,8 +1119,8 @@
   // The bufferized result type is the same as the bufferized type of the
   // corresponding output operand.
   return bufferization::getBufferType(
-      forallOp.getOutputs()[value.cast<OpResult>().getResultNumber()],
-      options, fixedTypes);
+      forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
+      fixedTypes);
 }
 
 bool isRepetitiveRegion(Operation *op, unsigned index) const {
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
@@ -43,7 +43,7 @@
   while (value) {
     if (value == forOp.getRegionIterArgs()[arg])
       return true;
-    OpResult opResult = value.dyn_cast<OpResult>();
+    OpResult opResult = dyn_cast<OpResult>(value);
     if (!opResult)
       return false;
@@ -91,7 +91,7 @@
   LogicalResult matchAndRewrite(OpTy dimOp,
                                 PatternRewriter &rewriter) const override {
-    auto blockArg = dimOp.getSource().template dyn_cast<BlockArgument>();
+    auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
     if (!blockArg)
       return failure();
     auto forOp = dyn_cast<scf::ForOp>(blockArg.getParentBlock()->getParentOp());
@@ -139,7 +139,7 @@
     auto forOp = dimOp.getSource().template getDefiningOp<scf::ForOp>();
     if (!forOp)
       return failure();
-    auto opResult = dimOp.getSource().template cast<OpResult>();
+    auto opResult = cast<OpResult>(dimOp.getSource());
     unsigned resultNumber = opResult.getResultNumber();
     if (!isShapePreserving(forOp, resultNumber))
       return failure();
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -164,8 +164,7 @@
   clone->walk([&](Operation *nested) {
     for (OpOperand &operand : nested->getOpOperands()) {
       Operation *def = operand.get().getDefiningOp();
-      if ((def && !clone->isAncestor(def)) ||
-          operand.get().isa<BlockArgument>())
+      if ((def && !clone->isAncestor(def)) || isa<BlockArgument>(operand.get()))
         callback(&operand);
     }
   });
@@ -346,7 +345,7 @@
       rewriter.setInsertionPointAfter(newOp);
       continue;
     }
-    auto arg = operand->get().dyn_cast<BlockArgument>();
+    auto arg = dyn_cast<BlockArgument>(operand->get());
     if (arg && arg.getOwner() == forOp.getBody()) {
        arg.getOwner() == forOp.getBody()) {
      // If the value is a loop carried value coming from stage N + 1 remap,
      // it will become a direct use.
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -496,7 +496,7 @@
                                     ArrayRef<scf::ForOp> loops) {
  std::optional<OpOperand *> destinationIterArg;
  auto loopIt = loops.rbegin();
- while (auto iterArg = source->get().dyn_cast<BlockArgument>()) {
+ while (auto iterArg = dyn_cast<BlockArgument>(source->get())) {
    scf::ForOp loop = *loopIt;
    if (iterArg.getOwner()->getParentOp() != loop)
      break;
@@ -505,7 +505,7 @@
  }
  if (loopIt == loops.rend())
    destinationIterArg = source;
- return {source->get().dyn_cast<OpResult>(), destinationIterArg};
+ return {dyn_cast<OpResult>(source->get()), destinationIterArg};
}

/// Implementation of fusing producer of a single slice by computing the
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
--- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
@@ -42,8 +42,8 @@
                                PatternRewriter &rewriter) const override {
    SmallVector<NamedAttribute, 4> globalVarAttrs;

-   auto ptrType = op.getType().cast<spirv::PointerType>();
-   auto pointeeType = ptrType.getPointeeType().cast<spirv::StructType>();
+   auto ptrType = cast<spirv::PointerType>(op.getType());
+   auto pointeeType = cast<spirv::StructType>(ptrType.getPointeeType());
    spirv::StructType structType = VulkanLayoutUtils::decorateType(pointeeType);
    if (!structType)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -51,19 +51,19 @@
  // info create a variable of type !spirv.ptr<!spirv.struct<...>>. If
  // not it must already be a !spirv.ptr<!spirv.struct<...>>.
  auto varType = funcOp.getFunctionType().getInput(argIndex);
- if (varType.cast<spirv::SPIRVType>().isScalarOrVector()) {
+ if (cast<spirv::SPIRVType>(varType).isScalarOrVector()) {
    auto storageClass = abiInfo.getStorageClass();
    if (!storageClass)
      return nullptr;
    varType = spirv::PointerType::get(spirv::StructType::get(varType),
                                      *storageClass);
  }
- auto varPtrType = varType.cast<spirv::PointerType>();
- auto varPointeeType = varPtrType.getPointeeType().cast<spirv::StructType>();
+ auto varPtrType = cast<spirv::PointerType>(varType);
+ auto varPointeeType = cast<spirv::StructType>(varPtrType.getPointeeType());

  // Set the offset information.
  varPointeeType =
-     VulkanLayoutUtils::decorateType(varPointeeType).cast<spirv::StructType>();
+     cast<spirv::StructType>(VulkanLayoutUtils::decorateType(varPointeeType));

  if (!varPointeeType)
    return nullptr;
@@ -98,7 +98,7 @@
  // Starting with version 1.4, the interface’s storage classes are all
  // storage classes used in declaring all global variables referenced by the
  // entry point’s call tree." We should consider the target environment here.
- switch (var.getType().cast<spirv::PointerType>().getStorageClass()) {
+ switch (cast<spirv::PointerType>(var.getType()).getStorageClass()) {
  case spirv::StorageClass::Input:
  case spirv::StorageClass::Output:
    interfaceVarSet.insert(var.getOperation());
@@ -247,7 +247,7 @@
  // at the start of the function. It is probably better to do the load just
  // before the use. There might be multiple loads and currently there is no
  // easy way to replace all uses with a sequence of operations.
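// A sketch of what the rewritten materialization below emits for a scalar ABI
// argument (the op spellings are from the SPIR-V dialect; the symbol name and
// storage class are illustrative, not taken from this patch):
//   %addr = spirv.mlir.addressof @arg0_abi_var
//   %ptr  = spirv.AccessChain %addr[%zero]
//   %val  = spirv.Load "StorageBuffer" %ptr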
- if (argType.value().cast<spirv::SPIRVType>().isScalarOrVector()) {
+ if (cast<spirv::SPIRVType>(argType.value()).isScalarOrVector()) {
    auto zero =
        spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter);
    auto loadPtr = rewriter.create<spirv::AccessChainOp>(
@@ -287,7 +287,7 @@
  typeConverter.addSourceMaterialization([](OpBuilder &builder,
                                            spirv::PointerType type,
                                            ValueRange inputs, Location loc) {
-   if (inputs.size() != 1 || !inputs[0].getType().isa<spirv::PointerType>())
+   if (inputs.size() != 1 || !isa<spirv::PointerType>(inputs[0].getType()))
      return Value();
    return builder.create(loc, type, inputs[0]).getResult();
  });
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
--- a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
@@ -84,15 +84,13 @@
LogicalResult RewriteInsertsPass::collectInsertionChain(
    spirv::CompositeInsertOp op,
    SmallVectorImpl<spirv::CompositeInsertOp> &insertions) {
- auto indicesArrayAttr = op.getIndices().cast<ArrayAttr>();
+ auto indicesArrayAttr = cast<ArrayAttr>(op.getIndices());
  // TODO: handle nested composite object.
  if (indicesArrayAttr.size() == 1) {
-   auto numElements = op.getComposite()
-                          .getType()
-                          .cast<spirv::CompositeType>()
+   auto numElements = cast<spirv::CompositeType>(op.getComposite().getType())
                           .getNumElements();
-   auto index = indicesArrayAttr[0].cast<IntegerAttr>().getInt();
+   auto index = cast<IntegerAttr>(indicesArrayAttr[0]).getInt();
    // Need a last index to collect a sequential chain.
    if (index + 1 != numElements)
      return failure();
@@ -109,9 +107,9 @@
      return failure();
    --index;

-   indicesArrayAttr = op.getIndices().cast<ArrayAttr>();
+   indicesArrayAttr = cast<ArrayAttr>(op.getIndices());
    if ((indicesArrayAttr.size() != 1) ||
-       (indicesArrayAttr[0].cast<IntegerAttr>().getInt() != index))
+       (cast<IntegerAttr>(indicesArrayAttr[0]).getInt() != index))
      return failure();
  }
}
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -138,7 +138,7 @@
// SPIR-V dialect. Keeping it local till the use case arises.
static std::optional<int64_t>
getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
- if (type.isa<spirv::ScalarType>()) {
+ if (isa<spirv::ScalarType>(type)) {
    auto bitWidth = type.getIntOrFloatBitWidth();
    // According to the SPIR-V spec:
    // "There is no physical size or bit pattern defined for values with boolean
@@ -151,21 +151,21 @@
    return bitWidth / 8;
  }

- if (auto complexType = type.dyn_cast<ComplexType>()) {
+ if (auto complexType = dyn_cast<ComplexType>(type)) {
    auto elementSize = getTypeNumBytes(options, complexType.getElementType());
    if (!elementSize)
      return std::nullopt;
    return 2 * *elementSize;
  }

- if (auto vecType = type.dyn_cast<VectorType>()) {
+ if (auto vecType = dyn_cast<VectorType>(type)) {
    auto elementSize = getTypeNumBytes(options, vecType.getElementType());
    if (!elementSize)
      return std::nullopt;
    return vecType.getNumElements() * *elementSize;
  }

- if (auto memRefType = type.dyn_cast<MemRefType>()) {
+ if (auto memRefType = dyn_cast<MemRefType>(type)) {
    // TODO: Layout should also be controlled by the ABI attributes. For now
    // using the layout from MemRef.
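    // Worked example for the memref branch below (assuming an identity layout
    // and zero offset; the numbers are illustrative, not from the patch):
    // memref<4x8xf32> spans 4 * 8 = 32 f32 elements, so the returned size is
    // (offset + memrefSize) * elementSize = (0 + 32) * 4 = 128 bytes.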
    int64_t offset;
@@ -197,7 +197,7 @@
    return (offset + memrefSize) * *elementSize;
  }

- if (auto tensorType = type.dyn_cast<TensorType>()) {
+ if (auto tensorType = dyn_cast<TensorType>(type)) {
    if (!tensorType.hasStaticShape())
      return std::nullopt;
@@ -245,12 +245,12 @@
    return nullptr;
  }

- if (auto floatType = type.dyn_cast<FloatType>()) {
+ if (auto floatType = dyn_cast<FloatType>(type)) {
    LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
    return Builder(targetEnv.getContext()).getF32Type();
  }

- auto intType = type.cast<IntegerType>();
+ auto intType = cast<IntegerType>(type);
  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
  return IntegerType::get(targetEnv.getContext(), /*width=*/32,
                          intType.getSignedness());
@@ -293,8 +293,8 @@
  // Get extension and capability requirements for the given type.
  SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
  SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
- type.cast<spirv::ScalarType>().getExtensions(extensions, storageClass);
- type.cast<spirv::ScalarType>().getCapabilities(capabilities, storageClass);
+ cast<spirv::ScalarType>(type).getExtensions(extensions, storageClass);
+ cast<spirv::ScalarType>(type).getCapabilities(capabilities, storageClass);

  // If all requirements are met, then we can accept this type as-is.
  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
@@ -389,8 +389,8 @@
        << "using non-8-bit storage for bool types unimplemented");
    return nullptr;
  }
- auto elementType = IntegerType::get(type.getContext(), numBoolBits)
-                        .dyn_cast<spirv::ScalarType>();
+ auto elementType = dyn_cast<spirv::ScalarType>(
+     IntegerType::get(type.getContext(), numBoolBits));
  if (!elementType)
    return nullptr;
  Type arrayElemType =
@@ -429,7 +429,7 @@
static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
                              const SPIRVConversionOptions &options,
                              MemRefType type) {
- auto attr = type.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
+ auto attr = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
  if (!attr) {
    LLVM_DEBUG(
        llvm::dbgs()
@@ -441,24 +441,24 @@
  }
  spirv::StorageClass storageClass = attr.getValue();

- if (type.getElementType().isa<IntegerType>() &&
+ if (isa<IntegerType>(type.getElementType()) &&
      type.getElementTypeBitWidth() == 1) {
    return convertBoolMemrefType(targetEnv, options, type, storageClass);
  }

  Type arrayElemType;
  Type elementType = type.getElementType();
- if (auto vecType = elementType.dyn_cast<VectorType>()) {
+ if (auto vecType = dyn_cast<VectorType>(elementType)) {
    arrayElemType =
        convertVectorType(targetEnv, options, vecType, storageClass);
- } else if (auto complexType = elementType.dyn_cast<ComplexType>()) {
+ } else if (auto complexType = dyn_cast<ComplexType>(elementType)) {
    arrayElemType =
        convertComplexType(targetEnv, options, complexType, storageClass);
- } else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) {
+ } else if (auto scalarType = dyn_cast<spirv::ScalarType>(elementType)) {
    arrayElemType =
        convertScalarType(targetEnv, options, scalarType, storageClass);
- } else if (auto indexType = elementType.dyn_cast<IndexType>()) {
-   type = convertIndexElementType(type, options).cast<MemRefType>();
+ } else if (auto indexType = dyn_cast<IndexType>(elementType)) {
+   type = cast<MemRefType>(convertIndexElementType(type, options));
    arrayElemType = type.getElementType();
  } else {
    LLVM_DEBUG(
@@ -523,13 +523,13 @@
  addConversion([this](IndexType /*indexType*/) { return getIndexType(); });

  addConversion([this](IntegerType intType) -> std::optional<Type> {
-   if (auto scalarType = intType.dyn_cast<spirv::ScalarType>())
+   if (auto scalarType = dyn_cast<spirv::ScalarType>(intType))
      return convertScalarType(this->targetEnv, this->options, scalarType);
    return Type();
  });

  addConversion([this](FloatType floatType) -> std::optional<Type> {
-   if (auto scalarType = floatType.dyn_cast<spirv::ScalarType>())
+   if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
      return convertScalarType(this->targetEnv, this->options, scalarType);
    return Type();
  });
@@ -722,7 +722,7 @@
static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
                                                       unsigned elementCount) {
  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
-   auto ptrType = varOp.getType().dyn_cast<spirv::PointerType>();
+   auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
    if (!ptrType)
      continue;
@@ -730,11 +730,11 @@
    // block statically used per shader entry point." So we should always reuse
    // the existing one.
    if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
-     auto numElements = ptrType.getPointeeType()
-                            .cast<spirv::StructType>()
-                            .getElementType(0)
-                            .cast<spirv::ArrayType>()
-                            .getNumElements();
+     auto numElements =
+         cast<spirv::ArrayType>(
+             ptrType.getPointeeType().cast<spirv::StructType>().getElementType(
+                 0))
+             .getNumElements();
      if (numElements == elementCount)
        return varOp;
    }
@@ -864,8 +864,8 @@
    linearizeIndex(indices, strides, offset, indexType, loc, builder);
  }
  Type pointeeType =
-     basePtr.getType().cast<spirv::PointerType>().getPointeeType();
- if (pointeeType.isa<spirv::ArrayType>()) {
+     cast<spirv::PointerType>(basePtr.getType()).getPointeeType();
+ if (isa<spirv::ArrayType>(pointeeType)) {
    linearizedIndices.push_back(linearIndex);
    return builder.create<spirv::AccessChainOp>(loc, basePtr,
                                                linearizedIndices);
@@ -953,7 +953,7 @@
  // Ensure that all types have been converted to SPIRV types.
  if (llvm::any_of(valueTypes,
-                  [](Type t) { return !t.isa<spirv::SPIRVType>(); }))
+                  [](Type t) { return !isa<spirv::SPIRVType>(t); }))
    return false;

  // Special treatment for global variables, whose type requirements are
@@ -967,13 +967,13 @@
  SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
  for (Type valueType : valueTypes) {
    typeExtensions.clear();
-   valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
+   cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
    if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
                                          typeExtensions)))
      return false;

    typeCapabilities.clear();
-   valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
+   cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
    if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
                                           typeCapabilities)))
      return false;
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
@@ -41,7 +41,7 @@
//===----------------------------------------------------------------------===//
Attribute getScalarOrSplatAttr(Type type, int64_t value) {
  APInt sizedValue(getElementTypeOrSelf(type).getIntOrFloatBitWidth(), value);
- if (auto intTy = type.dyn_cast<IntegerType>())
+ if (auto intTy = dyn_cast<IntegerType>(type))
    return IntegerAttr::get(intTy, sizedValue);

  return SplatElementsAttr::get(cast<ShapedType>(type), sizedValue);
@@ -149,7 +149,7 @@
  // Currently, WGSL only supports 32-bit integer types. Any other integer
  // types should already have been promoted/demoted to i32.
- auto elemTy = getElementTypeOrSelf(lhs.getType()).cast<IntegerType>();
+ auto elemTy = cast<IntegerType>(getElementTypeOrSelf(lhs.getType()));
  if (elemTy.getIntOrFloatBitWidth() != 32)
    return rewriter.notifyMatchFailure(
        loc,
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
--- a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
@@ -65,16 +65,16 @@
/// `!spirv.ptr<!spirv.struct<!spirv.rtarray<...>>>`. Returns null type
/// otherwise.
static Type getRuntimeArrayElementType(Type type) {
- auto ptrType = type.dyn_cast<spirv::PointerType>();
+ auto ptrType = dyn_cast<spirv::PointerType>(type);
  if (!ptrType)
    return {};

- auto structType = ptrType.getPointeeType().dyn_cast<spirv::StructType>();
+ auto structType = dyn_cast<spirv::StructType>(ptrType.getPointeeType());
  if (!structType || structType.getNumElements() != 1)
    return {};

  auto rtArrayType =
-     structType.getElementType(0).dyn_cast<spirv::RuntimeArrayType>();
+     dyn_cast<spirv::RuntimeArrayType>(structType.getElementType(0));
  if (!rtArrayType)
    return {};
@@ -97,7 +97,7 @@
  for (const auto &indexedTypes : llvm::enumerate(types)) {
    spirv::SPIRVType type = indexedTypes.value();
    assert(type.isScalarOrVector());
-   if (auto vectorType = type.dyn_cast<VectorType>()) {
+   if (auto vectorType = dyn_cast<VectorType>(type)) {
      if (vectorType.getNumElements() % 2 != 0)
        return std::nullopt; // Odd-sized vector has special layout
                             // requirements.
@@ -277,7 +277,7 @@
    if (!elementType)
      return; // Unexpected resource variable type.
-   auto type = elementType.cast<spirv::SPIRVType>();
+   auto type = cast<spirv::SPIRVType>(elementType);
    if (!type.isScalarOrVector())
      return; // Unexpected resource element type.
@@ -370,7 +370,7 @@
    Location loc = acOp.getLoc();

-   if (srcElemType.isIntOrFloat() && dstElemType.isa<VectorType>()) {
+   if (srcElemType.isIntOrFloat() && isa<VectorType>(dstElemType)) {
      // The source indices are for a buffer with scalar element types. Rewrite
      // them into a buffer with vector element types. We need to scale the last
      // index for the vector as a whole, then add one level of index for inside
@@ -398,7 +398,7 @@
    }

    if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) ||
-       (srcElemType.isa<VectorType>() && dstElemType.isa<VectorType>())) {
+       (isa<VectorType>(srcElemType) && isa<VectorType>(dstElemType))) {
      // The source indices are for a buffer with larger bitwidth scalar/vector
      // element types. Rewrite them into a buffer with smaller bitwidth element
      // types. We only need to scale the last index.
@@ -433,10 +433,10 @@
  LogicalResult
  matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
-   auto srcPtrType = loadOp.getPtr().getType().cast<spirv::PointerType>();
-   auto srcElemType = srcPtrType.getPointeeType().cast<spirv::SPIRVType>();
-   auto dstPtrType = adaptor.getPtr().getType().cast<spirv::PointerType>();
-   auto dstElemType = dstPtrType.getPointeeType().cast<spirv::SPIRVType>();
+   auto srcPtrType = cast<spirv::PointerType>(loadOp.getPtr().getType());
+   auto srcElemType = cast<spirv::SPIRVType>(srcPtrType.getPointeeType());
+   auto dstPtrType = cast<spirv::PointerType>(adaptor.getPtr().getType());
+   auto dstElemType = cast<spirv::SPIRVType>(dstPtrType.getPointeeType());

    Location loc = loadOp.getLoc();
    auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.getPtr());
@@ -454,7 +454,7 @@
    }

    if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) ||
-       (srcElemType.isa<VectorType>() && dstElemType.isa<VectorType>())) {
+       (isa<VectorType>(srcElemType) && isa<VectorType>(dstElemType))) {
      // The source and destination have scalar types of different bitwidths, or
      // vector types of different component counts. For such cases, we load
      // multiple smaller bitwidth values and construct a larger bitwidth one.
@@ -495,13 +495,13 @@
      // type.
      Type vectorType = srcElemType;
-     if (!srcElemType.isa<VectorType>())
+     if (!isa<VectorType>(srcElemType))
        vectorType = VectorType::get({ratio}, dstElemType);

      // If both the source and destination are vector types, we need to make
      // sure the scalar type is the same for composite construction later.
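      // Illustrative shape of the multi-load path described above (a sketch,
      // not taken from this pass's tests): an i64 load from a buffer unified
      // to vector<2xi32> becomes two i32 loads followed by
      //   %v = spirv.CompositeConstruct %lo, %hi : (i32, i32) -> vector<2xi32>
      //   %r = spirv.Bitcast %v : vector<2xi32> to i64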
-     if (auto srcElemVecType = srcElemType.dyn_cast<VectorType>())
-       if (auto dstElemVecType = dstElemType.dyn_cast<VectorType>()) {
+     if (auto srcElemVecType = dyn_cast<VectorType>(srcElemType))
+       if (auto dstElemVecType = dyn_cast<VectorType>(dstElemType)) {
          if (srcElemVecType.getElementType() !=
              dstElemVecType.getElementType()) {
            int64_t count =
@@ -515,7 +515,7 @@
      Value vectorValue = rewriter.create<spirv::CompositeConstructOp>(
          loc, vectorType, components);

-     if (!srcElemType.isa<VectorType>())
+     if (!isa<VectorType>(srcElemType))
        vectorValue =
            rewriter.create<spirv::BitcastOp>(loc, srcElemType, vectorValue);
      rewriter.replaceOp(loadOp, vectorValue);
@@ -534,9 +534,9 @@
  matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcElemType =
-       storeOp.getPtr().getType().cast<spirv::PointerType>().getPointeeType();
+       cast<spirv::PointerType>(storeOp.getPtr().getType()).getPointeeType();
    auto dstElemType =
-       adaptor.getPtr().getType().cast<spirv::PointerType>().getPointeeType();
+       cast<spirv::PointerType>(adaptor.getPtr().getType()).getPointeeType();
    if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
      return rewriter.notifyMatchFailure(storeOp, "not scalar type");
    if (!areSameBitwidthScalarType(srcElemType, dstElemType))
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
--- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
@@ -159,13 +159,13 @@
  SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
  for (Type valueType : valueTypes) {
    typeExtensions.clear();
-   valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
+   cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
    if (failed(checkAndUpdateExtensionRequirements(
            op, targetEnv, typeExtensions, deducedExtensions)))
      return WalkResult::interrupt();

    typeCapabilities.clear();
-   valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
+   cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
    if (failed(checkAndUpdateCapabilityRequirements(
            op, targetEnv, typeCapabilities, deducedCapabilities)))
      return WalkResult::interrupt();
diff --git a/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp b/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
--- a/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
+++ b/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
@@ -53,7 +53,7 @@
    // must be a runtime array.
    assert(memberSize != std::numeric_limits<Size>().max() ||
           (i + 1 == e &&
-           structType.getElementType(i).isa<spirv::RuntimeArrayType>()));
+           isa<spirv::RuntimeArrayType>(structType.getElementType(i))));
    // According to the Vulkan spec:
    // "A structure has a base alignment equal to the largest base alignment of
    // any of its members."
@@ -79,23 +79,23 @@
Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
                                     VulkanLayoutUtils::Size &alignment) {
- if (type.isa<spirv::ScalarType>()) {
+ if (isa<spirv::ScalarType>(type)) {
    alignment = getScalarTypeAlignment(type);
    // Vulkan spec does not specify any padding for a scalar type.
    size = alignment;
    return type;
  }
- if (auto structType = type.dyn_cast<spirv::StructType>())
+ if (auto structType = dyn_cast<spirv::StructType>(type))
    return decorateType(structType, size, alignment);
- if (auto arrayType = type.dyn_cast<spirv::ArrayType>())
+ if (auto arrayType = dyn_cast<spirv::ArrayType>(type))
    return decorateType(arrayType, size, alignment);
- if (auto vectorType = type.dyn_cast<VectorType>())
+ if (auto vectorType = dyn_cast<VectorType>(type))
    return decorateType(vectorType, size, alignment);
- if (auto arrayType = type.dyn_cast<spirv::RuntimeArrayType>()) {
+ if (auto arrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
    size = std::numeric_limits<Size>().max();
    return decorateType(arrayType, alignment);
  }
- if (type.isa<spirv::PointerType>()) {
+ if (isa<spirv::PointerType>(type)) {
    // TODO: Add support for `PhysicalStorageBufferAddresses`.
    return nullptr;
  }
@@ -161,13 +161,13 @@
}

bool VulkanLayoutUtils::isLegalType(Type type) {
- auto ptrType = type.dyn_cast<spirv::PointerType>();
+ auto ptrType = dyn_cast<spirv::PointerType>(type);
  if (!ptrType) {
    return true;
  }

  auto storageClass = ptrType.getStorageClass();
- auto structType = ptrType.getPointeeType().dyn_cast<spirv::StructType>();
+ auto structType = dyn_cast<spirv::StructType>(ptrType.getPointeeType());
  if (!structType) {
    return true;
  }
diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -64,7 +64,7 @@
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(assumingOp->getResultTypes())) {
-     if (it.value().isa<TensorType>()) {
+     if (isa<TensorType>(it.value())) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            assumingOp.getLoc(), newOp->getResult(it.index())));
      } else {
@@ -116,7 +116,7 @@
    auto yieldOp = cast<shape::AssumingYieldOp>(op);
    SmallVector<Value> newResults;
    for (Value value : yieldOp.getOperands()) {
-     if (value.getType().isa<TensorType>()) {
+     if (isa<TensorType>(value.getType())) {
        FailureOr<Value> buffer = getBuffer(rewriter, value, options);
        if (failed(buffer))
          return failure();
diff --git a/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp b/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
--- a/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
@@ -133,7 +133,7 @@
  for (shape::WithOp withOp : allWithOps) {
    Value value = withOp.getOperand();
    Value shape = withOp.getShape();
-   RankedTensorType rankedType = value.getType().dyn_cast<RankedTensorType>();
+   RankedTensorType rankedType = dyn_cast<RankedTensorType>(value.getType());
    if (rankedType == nullptr)
      continue;
diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
@@ -41,7 +41,7 @@
  options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
                                      const BufferizationOptions &options) {
    return getMemRefTypeWithStaticIdentityLayout(
-       value.getType().cast<TensorType>(), memorySpace);
+       cast<TensorType>(value.getType()), memorySpace);
  };
  if (analysisOnly) {
    options.testAnalysisOnly = true;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -28,7 +28,7 @@
static std::optional<std::pair<Value, Value>>
genSplitSparseConstant(OpBuilder &builder, Location loc, Value tensor) {
  if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) {
-   if (auto a = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
+   if (auto a = dyn_cast<SparseElementsAttr>(constOp.getValue())) {
      auto coordinates = builder.create<arith::ConstantOp>(loc, a.getIndices());
      auto values = builder.create<arith::ConstantOp>(loc, a.getValues());
      return std::make_pair(coordinates, values);
@@ -94,7 +94,7 @@
OverheadType mlir::sparse_tensor::overheadTypeEncoding(Type tp) {
  if (tp.isIndex())
    return OverheadType::kIndex;
- if (auto intTp = tp.dyn_cast<IntegerType>())
+ if (auto intTp = dyn_cast<IntegerType>(tp))
    return overheadTypeEncoding(intTp.getWidth());
  llvm_unreachable("Unknown overhead type");
}
@@ -169,7 +169,7 @@
    return PrimaryType::kI16;
  if (elemTp.isInteger(8))
    return PrimaryType::kI8;
- if (auto complexTp = elemTp.dyn_cast<ComplexType>()) {
+ if (auto complexTp = dyn_cast<ComplexType>(elemTp)) {
    auto complexEltTp = complexTp.getElementType();
    if (complexEltTp.isF64())
      return PrimaryType::kC64;
@@ -205,10 +205,10 @@
    return value;

  // int <=> index
- if (srcTp.isa<IndexType>() || dstTp.isa<IndexType>())
+ if (isa<IndexType>(srcTp) || isa<IndexType>(dstTp))
    return builder.create<arith::IndexCastOp>(loc, dstTp, value);

- const auto srcIntTp = srcTp.dyn_cast_or_null<IntegerType>();
+ const auto srcIntTp = dyn_cast_or_null<IntegerType>(srcTp);
  const bool isUnsignedCast = srcIntTp ? srcIntTp.isUnsigned() : false;
  return mlir::convertScalarToDtype(builder, loc, value, dstTp,
                                    isUnsignedCast);
}
@@ -216,7 +216,7 @@
Value sparse_tensor::genIndexLoad(OpBuilder &builder, Location loc, Value mem,
                                  Value s) {
  Value load = builder.create<memref::LoadOp>(loc, mem, s);
- if (!load.getType().isa<IndexType>()) {
+ if (!isa<IndexType>(load.getType())) {
    if (load.getType().getIntOrFloatBitWidth() < 64)
      load = builder.create<arith::ExtUIOp>(loc, builder.getI64Type(), load);
    load =
@@ -226,14 +226,14 @@
}

mlir::TypedAttr mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
- if (tp.isa<FloatType>())
+ if (isa<FloatType>(tp))
    return builder.getFloatAttr(tp, 1.0);
- if (tp.isa<IndexType>())
+ if (isa<IndexType>(tp))
    return builder.getIndexAttr(1);
- if (auto intTp = tp.dyn_cast<IntegerType>())
+ if (auto intTp = dyn_cast<IntegerType>(tp))
    return builder.getIntegerAttr(tp, APInt(intTp.getWidth(), 1));
- if (tp.isa<RankedTensorType, VectorType>()) {
-   auto shapedTp = tp.cast<ShapedType>();
+ if (isa<RankedTensorType, VectorType>(tp)) {
+   auto shapedTp = cast<ShapedType>(tp);
    if (auto one = getOneAttr(builder, shapedTp.getElementType()))
      return DenseElementsAttr::get(shapedTp, one);
  }
@@ -244,13 +244,13 @@
                                     Value v) {
  Type tp = v.getType();
  Value zero = constantZero(builder, loc, tp);
- if (tp.isa<FloatType>())
+ if (isa<FloatType>(tp))
    return builder.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
                                         zero);
  if (tp.isIntOrIndex())
    return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
                                         zero);
- if (tp.dyn_cast<ComplexType>())
+ if (dyn_cast<ComplexType>(tp))
    return builder.create<complex::NotEqualOp>(loc, v, zero);
  llvm_unreachable("Non-numeric type");
}
@@ -580,12 +580,12 @@
  }
  // Remap value.
  Value val;
- if (attr.getElementType().isa<ComplexType>()) {
-   auto valAttr = elems[i].second.cast<ArrayAttr>();
+ if (isa<ComplexType>(attr.getElementType())) {
+   auto valAttr = cast<ArrayAttr>(elems[i].second);
    val = builder.create<complex::ConstantOp>(loc, attr.getElementType(),
                                              valAttr);
  } else {
-   auto valAttr = elems[i].second.cast<TypedAttr>();
+   auto valAttr = cast<TypedAttr>(elems[i].second);
    val = builder.create<arith::ConstantOp>(loc, valAttr);
  }
  assert(val);
@@ -597,7 +597,7 @@
                                       size_t size, Value mem,
                                       size_t offsetIdx, Value offsetVal) {
#ifndef NDEBUG
- const auto memTp = mem.getType().cast<MemRefType>();
+ const auto memTp = cast<MemRefType>(mem.getType());
  assert(memTp.getRank() == 1);
  const DynSize memSh = memTp.getDimSize(0);
  assert(ShapedType::isDynamic(memSh) ||
         memSh >= static_cast<DynSize>(size));
@@ -619,7 +619,7 @@
                                    ValueRange vs, size_t offsetIdx,
                                    Value offsetVal) {
#ifndef NDEBUG
  const size_t vsize = vs.size();
- const auto memTp = mem.getType().cast<MemRefType>();
+ const auto memTp = cast<MemRefType>(mem.getType());
  assert(memTp.getRank() == 1);
  const DynSize memSh = memTp.getDimSize(0);
  assert(ShapedType::isDynamic(memSh) ||
         memSh >= static_cast<DynSize>(vsize));
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -350,7 +350,7 @@
  // on positions.
  for (TensorId t = 0, numTensors = getNumTensors(); t < numTensors; t++) {
    const Value tensor = tensors[t];
-   const auto rtp = tensor.getType().dyn_cast<RankedTensorType>();
+   const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
    if (!rtp)
      // Skips only scalar, zero ranked tensor still need to be bufferized and
      // (probably) filled with zeros by users.
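All of these hunks are instances of one mechanical rewrite: the deprecated member-style casts on Value, Type, and Attribute move to the free functions isa/cast/dyn_cast/dyn_cast_or_null from LLVM's casting infrastructure. A minimal self-contained sketch of the two shapes (the helper name and `v` are illustrative placeholders, not code from this patch):

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

using namespace mlir;

// Returns the ranked tensor type of `v`, or a null type if it has none.
static RankedTensorType getRankedTypeIfAny(Value v) {
  // '-' lines used the member form: v.getType().dyn_cast<RankedTensorType>().
  // '+' lines use the free function; dyn_cast yields a null object on
  // mismatch, cast<T>(...) asserts instead, and isa<T>(...) only tests.
  return dyn_cast<RankedTensorType>(v.getType());
}

The choice of isa, cast, dyn_cast, or dyn_cast_or_null on each '+' line mirrors the member function it replaces, so behavior is unchanged.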
@@ -432,7 +432,7 @@
  Type indexType = builder.getIndexType();
  Value c0 = constantZero(builder, loc, indexType);
  for (TensorId t = 0, e = tensors.size(); t < e; t++) {
-   auto rtp = tensors[t].getType().dyn_cast<RankedTensorType>();
+   auto rtp = dyn_cast<RankedTensorType>(tensors[t].getType());
    if (!rtp)
      continue;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -100,7 +100,7 @@
/// completion. Needs to cast the buffer to a unranked buffer.
static Value genHostRegisterMemref(OpBuilder &builder, Location loc,
                                   Value mem) {
- MemRefType memTp = mem.getType().cast<MemRefType>();
+ MemRefType memTp = cast<MemRefType>(mem.getType());
  UnrankedMemRefType resTp =
      UnrankedMemRefType::get(memTp.getElementType(), /*memorySpace=*/0);
  Value cast = builder.create<memref::CastOp>(loc, resTp, mem);
@@ -133,7 +133,7 @@
/// that feature does not seem to be fully supported yet.
static gpu::AllocOp genAllocMemRef(OpBuilder &builder, Location loc, Value mem,
                                   Value token) {
- auto tp = mem.getType().cast<MemRefType>();
+ auto tp = cast<MemRefType>(mem.getType());
  auto elemTp = tp.getElementType();
  auto shape = tp.getShape();
  auto memTp = MemRefType::get(shape, elemTp);
@@ -304,7 +304,7 @@
  for (OpOperand &o : op->getOpOperands()) {
    Value val = o.get();
    Block *block;
-   if (auto arg = val.dyn_cast<BlockArgument>())
+   if (auto arg = dyn_cast<BlockArgument>(val))
      block = arg.getOwner();
    else
      block = val.getDefiningOp()->getBlock();
@@ -321,7 +321,7 @@
    Type tp = val.getType();
    if (val.getDefiningOp<arith::ConstantOp>())
      constants.push_back(val);
-   else if (tp.isa<FloatType>() || tp.isIntOrIndex())
+   else if (isa<FloatType>(tp) || tp.isIntOrIndex())
      scalars.push_back(val);
    else if (isa<MemRefType>(tp))
      buffers.push_back(val);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
@@ -111,9 +111,9 @@
  Value metaData = builder.create<LLVM::UndefOp>(loc, structType);
  SpecifierStructBuilder md(metaData);
  if (!source) {
-   auto memSizeArrayType = structType.cast<LLVM::LLVMStructType>()
-                               .getBody()[kMemSizePosInSpecifier]
-                               .cast<LLVM::LLVMArrayType>();
+   auto memSizeArrayType =
+       cast<LLVM::LLVMArrayType>(structType.cast<LLVM::LLVMStructType>()
+                                     .getBody()[kMemSizePosInSpecifier]);
    Value zero = constantZero(builder, loc, memSizeArrayType.getElementType());
    // Fill memSizes array with zero.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -80,7 +80,7 @@
                     Value idx) {
  idx = genCast(builder, loc, idx, builder.getIndexType());
  val = genCast(builder, loc, val,
-               mem.getType().cast<MemRefType>().getElementType());
+               cast<MemRefType>(mem.getType()).getElementType());
  builder.create<memref::StoreOp>(loc, val, mem, idx);
}
@@ -253,7 +253,7 @@
  case SparseTensorFieldKind::CrdMemRef:
  case SparseTensorFieldKind::ValMemRef:
    field = createAllocation(
-       builder, loc, fType.cast<MemRefType>(),
+       builder, loc, cast<MemRefType>(fType),
        (fKind == SparseTensorFieldKind::PosMemRef)   ? posHeuristic
        : (fKind == SparseTensorFieldKind::CrdMemRef) ? crdHeuristic
                                                      : valHeuristic,
@@ -779,7 +779,7 @@
  fields.reserve(desc.getNumFields());
  // Memcpy on memref fields.
  for (auto field : desc.getMemRefFields()) {
-   auto memrefTp = field.getType().cast<MemRefType>();
+   auto memrefTp = cast<MemRefType>(field.getType());
    auto size = rewriter.create<memref::DimOp>(loc, field, 0);
    auto copied =
        rewriter.create<memref::AllocOp>(loc, memrefTp, ValueRange{size});
@@ -1128,7 +1128,7 @@
    auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource());
    SmallVector<Value> fields;
    foreachFieldAndTypeInSparseTensor(
-       SparseTensorType(op.getResult().getType().cast<RankedTensorType>()),
+       SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
        [&rewriter, &fields, srcDesc,
         loc](Type fTp, FieldIndex fIdx, SparseTensorFieldKind fKind, Level lvl,
              DimLevelType /*dlt*/) -> bool {
@@ -1143,7 +1143,7 @@
          // values.
          Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0);
          auto dstMem = rewriter.create<memref::AllocOp>(
-             loc, fTp.cast<MemRefType>(), sz);
+             loc, cast<MemRefType>(fTp), sz);
          if (fTp != srcMem.getType()) {
            // Converts elements type.
            scf::buildLoopNest(
@@ -1397,7 +1397,7 @@
    }
    assert(field);

-   if (auto memrefTp = field.getType().dyn_cast<MemRefType>();
+   if (auto memrefTp = dyn_cast<MemRefType>(field.getType());
        memrefTp && memrefTp.getRank() > 1) {
      ReassociationIndices reassociation;
      for (int i = 0, e = memrefTp.getRank(); i < e; i++)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -399,7 +399,7 @@
/// (which can be either dim- or lvl-coords, depending on context).
static Value genGetNextCall(OpBuilder &builder, Location loc, Value iter,
                            Value coords, Value elemPtr) {
- Type elemTp = elemPtr.getType().cast<MemRefType>().getElementType();
+ Type elemTp = cast<MemRefType>(elemPtr.getType()).getElementType();
  SmallString<10> name{"getNext", primaryTypeFunctionSuffix(elemTp)};
  SmallVector<Value> params{iter, coords, elemPtr};
  Type i1 = builder.getI1Type();
@@ -1045,7 +1045,7 @@
  matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resTp = op.getType();
-   Type posTp = resTp.cast<ShapedType>().getElementType();
+   Type posTp = cast<ShapedType>(resTp).getElementType();
    SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
    Value lvl = constantIndex(rewriter, op->getLoc(), op.getLevel());
    replaceOpWithFuncCall(rewriter, op, name, resTp, {adaptor.getTensor(), lvl},
@@ -1064,7 +1064,7 @@
                  ConversionPatternRewriter &rewriter) const override {
    // TODO: use `SparseTensorType::getCrdType` instead.
    Type resType = op.getType();
-   const Type crdTp = resType.cast<ShapedType>().getElementType();
+   const Type crdTp = cast<ShapedType>(resType).getElementType();
    SmallString<19> name{"sparseCoordinates",
                         overheadTypeFunctionSuffix(crdTp)};
    Location loc = op->getLoc();
@@ -1096,7 +1096,7 @@
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
-   auto resType = op.getType().cast<MemRefType>();
+   auto resType = cast<MemRefType>(op.getType());
    rewriter.replaceOp(op, genValuesCall(rewriter, op.getLoc(), resType,
                                         adaptor.getOperands()));
    return success();
@@ -1113,7 +1113,7 @@
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Query values array size for the actually stored values size.
-   Type eltType = op.getTensor().getType().cast<ShapedType>().getElementType();
+   Type eltType = cast<ShapedType>(op.getTensor().getType()).getElementType();
    auto resTp = MemRefType::get({ShapedType::kDynamic}, eltType);
    Value values = genValuesCall(rewriter, loc, resTp, adaptor.getOperands());
    rewriter.replaceOpWithNewOp<memref::DimOp>(op, values,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -79,7 +79,7 @@
// Helper to detect chain of multiplications that do not involve x.
static bool isMulChain(Value val, Value x) {
- if (auto arg = val.dyn_cast<BlockArgument>())
+ if (auto arg = dyn_cast<BlockArgument>(val))
    return arg != x;
  if (auto *def = val.getDefiningOp()) {
    if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
@@ -105,7 +105,7 @@
// Helper to detect direct yield of a zero value.
static bool isZeroYield(GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
- if (auto arg = yieldOp.getOperand(0).dyn_cast<BlockArgument>()) {
+ if (auto arg = dyn_cast<BlockArgument>(yieldOp.getOperand(0))) {
    if (arg.getOwner()->getParentOp() == op) {
      return isZeroValue(op->getOperand(arg.getArgNumber()));
    }
@@ -719,7 +719,7 @@
  bool fromSparseConst = false;
  if (auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>()) {
-   if (constOp.getValue().dyn_cast<SparseElementsAttr>()) {
+   if (dyn_cast<SparseElementsAttr>(constOp.getValue())) {
      fromSparseConst = true;
    }
  }
@@ -974,7 +974,7 @@
  // Special-case: for each over a sparse constant uses its own rewriting
  // rule.
  if (auto constOp = input.getDefiningOp<arith::ConstantOp>()) {
-   if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
+   if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue())) {
      return genForeachOnSparseConstant(op, rewriter, attr);
    }
  }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -88,9 +88,9 @@
  // Overrides method from AffineExprVisitor.
  void visitDimExpr(AffineDimExpr expr) {
    if (pickedDim == nullptr ||
-       pickIterType == iterTypes[expr.getPosition()]
-                           .cast<linalg::IteratorTypeAttr>()
-                           .getValue()) {
+       pickIterType ==
+           cast<linalg::IteratorTypeAttr>(iterTypes[expr.getPosition()])
+               .getValue()) {
      pickedDim = expr;
    }
  }
@@ -344,7 +344,7 @@
  // we can't use `getRankedTensorType`/`getSparseTensorType` here.
  // However, we don't need to handle `StorageSpecifierType`, so we
  // can use `SparseTensorType` once we guard against non-tensors.
- const auto rtp = tensor.getType().dyn_cast<RankedTensorType>();
+ const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
  if (!rtp)
    return 0;
  const SparseTensorType stt(rtp);
@@ -1243,7 +1243,7 @@
  Location loc = op.getLoc();
  if (atStart) {
    auto dynShape = {ShapedType::kDynamic};
-   Type etp = tensor.getType().cast<ShapedType>().getElementType();
+   Type etp = cast<ShapedType>(tensor.getType()).getElementType();
    Type t1 = MemRefType::get(dynShape, etp);
    Type t2 = MemRefType::get(dynShape, builder.getI1Type());
    Type t3 = MemRefType::get(dynShape, builder.getIndexType());
@@ -1833,7 +1833,7 @@
  // required for sparse tensor slice rank reducing too.
  Level maxLvlRank = 0;
  for (auto operand : op.getOperands()) {
-   if (auto rtp = operand.getType().dyn_cast<RankedTensorType>()) {
+   if (auto rtp = dyn_cast<RankedTensorType>(operand.getType())) {
      maxLvlRank = std::max(maxLvlRank, SparseTensorType(rtp).getLvlRank());
    }
  }
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -1061,8 +1061,8 @@
  if (expr.kind == TensorExp::Kind::kInvariant) {
    if (auto c = expr.val.getDefiningOp<complex::ConstantOp>()) {
      ArrayAttr arrayAttr = c.getValue();
-     return arrayAttr[0].cast<FloatAttr>().getValue().isZero() &&
-            arrayAttr[1].cast<FloatAttr>().getValue().isZero();
+     return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
+            cast<FloatAttr>(arrayAttr[1]).getValue().isZero();
    }
    if (auto c = expr.val.getDefiningOp<arith::ConstantIntOp>())
      return c.value() == 0;
@@ -1077,7 +1077,7 @@
  Type dtp = exp(e).val.getType();
  // Inspect source type. For vector types, apply the same
  // vectorization to the destination type.
- if (auto vtp = src.getType().dyn_cast<VectorType>())
+ if (auto vtp = dyn_cast<VectorType>(src.getType()))
    return VectorType::get(vtp.getNumElements(), dtp, vtp.getNumScalableDims());
  return dtp;
}
@@ -1085,7 +1085,7 @@
/// Ensures that sparse compiler can generate code for expression.
static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v) {
  // Arguments are always admissible.
- if (v.isa<BlockArgument>())
+ if (isa<BlockArgument>(v))
    return true;
  // Accept index anywhere.
  Operation *def = v.getDefiningOp();
@@ -1113,7 +1113,7 @@
}

std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
- if (auto arg = v.dyn_cast<BlockArgument>()) {
+ if (auto arg = dyn_cast<BlockArgument>(v)) {
    const TensorId tid = makeTensorId(arg.getArgNumber());
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
@@ -1346,8 +1346,8 @@
  case TensorExp::Kind::kAbsF:
    return rewriter.create<math::AbsFOp>(loc, v0);
  case TensorExp::Kind::kAbsC: {
-   auto type = v0.getType().cast<ComplexType>();
-   auto eltType = type.getElementType().cast<FloatType>();
+   auto type = cast<ComplexType>(v0.getType());
+   auto eltType = cast<FloatType>(type.getElementType());
    return rewriter.create<complex::AbsOp>(loc, eltType, v0);
  }
  case TensorExp::Kind::kAbsI:
@@ -1407,13 +1407,13 @@
  case TensorExp::Kind::kTruncI:
    return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCIm: {
-   auto type = v0.getType().cast<ComplexType>();
-   auto eltType = type.getElementType().cast<FloatType>();
+   auto type = cast<ComplexType>(v0.getType());
+   auto eltType = cast<FloatType>(type.getElementType());
    return rewriter.create<complex::ImOp>(loc, eltType, v0);
  }
  case TensorExp::Kind::kCRe: {
-   auto type = v0.getType().cast<ComplexType>();
-   auto eltType = type.getElementType().cast<FloatType>();
+   auto type = cast<ComplexType>(v0.getType());
+   auto eltType = cast<FloatType>(type.getElementType());
    return rewriter.create<complex::ReOp>(loc, eltType, v0);
  }
  case TensorExp::Kind::kBitCast:
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -60,20 +60,20 @@
  // type in case the input is an unranked tensor type.

  // Case 1: Casting an unranked tensor
- if (castOp.getSource().getType().isa<UnrankedTensorType>()) {
+ if (isa<UnrankedTensorType>(castOp.getSource().getType())) {
    // When casting to a ranked tensor, we cannot infer any static offset or
    // strides from the source. Assume fully dynamic.
    return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
  }

  // Case 2: Casting to an unranked tensor type
- if (castOp.getType().isa<UnrankedTensorType>()) {
+ if (isa<UnrankedTensorType>(castOp.getType())) {
    return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
  }

  // Case 3: Ranked tensor -> ranked tensor. The offsets and strides do not
  // change.
- auto rankedResultType = castOp.getType().cast<RankedTensorType>();
+ auto rankedResultType = cast<RankedTensorType>(castOp.getType());
  return MemRefType::get(
      rankedResultType.getShape(), rankedResultType.getElementType(),
      maybeSrcBufferType->cast<MemRefType>().getLayout(), memorySpace);
@@ -158,7 +158,7 @@
  if (failed(maybeBuffer))
    return failure();
  Value buffer = *maybeBuffer;
- auto bufferType = buffer.getType().cast<MemRefType>();
+ auto bufferType = cast<MemRefType>(buffer.getType());

  if (tensorResultType.getRank() == 0) {
    // 0-d collapses must go through a different op builder.
@@ -383,11 +383,9 @@
  SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
  SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
  SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
- return memref::SubViewOp::inferRankReducedResultType(
-            extractSliceOp.getType().getShape(),
-            srcMemrefType->cast<MemRefType>(), mixedOffsets, mixedSizes,
-            mixedStrides)
-     .cast<BaseMemRefType>();
+ return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
+     extractSliceOp.getType().getShape(), srcMemrefType->cast<MemRefType>(),
+     mixedOffsets, mixedSizes, mixedStrides));
}
@@ -459,7 +457,7 @@
  auto fromElementsOp = cast<tensor::FromElementsOp>(op);

  // Should the buffer be deallocated?
  bool dealloc = shouldDeallocateOpResult(
-     fromElementsOp.getResult().cast<OpResult>(), options);
+     cast<OpResult>(fromElementsOp.getResult()), options);

  // TODO: Implement memory space for this op.
  if (options.defaultMemorySpace != Attribute())
@@ -467,7 +465,7 @@
  // Allocate a buffer for the result.
  Location loc = op->getLoc();
- auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
+ auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
  auto shape = tensorType.getShape();
  // TODO: Create alloc_tensor ops during TensorCopyInsertion.
  FailureOr<Value> tensorAlloc =
@@ -540,7 +538,7 @@
                             ValueRange dynamicSizes,
                             Region &generateBody) {
  assert(generateBody.hasOneBlock() && "expected body with single block");
- auto tensorType = tensorDestination.getType().cast<RankedTensorType>();
+ auto tensorType = cast<RankedTensorType>(tensorDestination.getType());
  assert(generateBody.getNumArguments() == tensorType.getRank() &&
         "rank mismatch");
@@ -579,7 +577,7 @@
  auto generateOp = cast<tensor::GenerateOp>(op);

  // Should the buffer be deallocated?
  bool dealloc = shouldDeallocateOpResult(
-     generateOp.getResult().cast<OpResult>(), options);
+     cast<OpResult>(generateOp.getResult()), options);

  // TODO: Implement memory space for this op.
  if (options.defaultMemorySpace != Attribute())
@@ -800,12 +798,11 @@
    return failure();

  // Take a subview of the destination buffer.
- auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
+ auto dstMemrefType = cast<MemRefType>(dstMemref->getType());
  auto subviewMemRefType =
-     memref::SubViewOp::inferRankReducedResultType(
-         insertSliceOp.getSourceType().getShape(), dstMemrefType,
-         mixedOffsets, mixedSizes, mixedStrides)
-         .cast<MemRefType>();
+     cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
+         insertSliceOp.getSourceType().getShape(), dstMemrefType,
+         mixedOffsets, mixedSizes, mixedStrides));
  Value subView = rewriter.create<memref::SubViewOp>(
      loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
      mixedStrides);
@@ -900,7 +897,7 @@
  // Should the buffer be deallocated?
  bool dealloc =
-     shouldDeallocateOpResult(padOp.getResult().cast<OpResult>(), options);
+     shouldDeallocateOpResult(cast<OpResult>(padOp.getResult()), options);

  // Allocate a buffer for the padded result.
  FailureOr<Value> tensorAlloc =
      allocateTensorForShapedValue(rewriter, loc, padOp.getResult(),
@@ -992,7 +989,7 @@
    return failure();
  auto resultMemRefType = getMemRefType(
      reshapeOp.getResult(), options, /*layout=*/{},
-     srcBuffer->getType().cast<BaseMemRefType>().getMemorySpace());
+     cast<BaseMemRefType>(srcBuffer->getType()).getMemorySpace());
  replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
      rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);
  return success();
@@ -1039,14 +1036,13 @@
    return failure();

  // Take a subview of the destination buffer.
- auto destBufferType = destBuffer->getType().cast<MemRefType>();
+ auto destBufferType = cast<MemRefType>(destBuffer->getType());
  auto subviewMemRefType =
-     memref::SubViewOp::inferRankReducedResultType(
-         parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
-         parallelInsertSliceOp.getMixedOffsets(),
-         parallelInsertSliceOp.getMixedSizes(),
-         parallelInsertSliceOp.getMixedStrides())
-         .cast<MemRefType>();
+     cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
+         parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
+         parallelInsertSliceOp.getMixedOffsets(),
+         parallelInsertSliceOp.getMixedSizes(),
+         parallelInsertSliceOp.getMixedStrides()));
  Value subview = rewriter.create<memref::SubViewOp>(
      parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer,
      parallelInsertSliceOp.getMixedOffsets(),
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
@@ -29,7 +29,7 @@
/// Get the dimension size of a value of RankedTensor type at the
static OpFoldResult getShapeDimSize(OpBuilder &b, Location loc,
                                    Value rankedTensor, int64_t dimIdx) {
- RankedTensorType tensorType = rankedTensor.getType().cast<RankedTensorType>();
+ RankedTensorType tensorType = cast<RankedTensorType>(rankedTensor.getType());
  if (!tensorType.isDynamicDim(dimIdx)) {
    return b.getIndexAttr(tensorType.getDimSize(dimIdx));
  }
@@ -41,7 +41,7 @@
static SmallVector<OpFoldResult> getShapeDimSizes(OpBuilder &b, Location loc,
                                                  Value rankedTensor) {
  SmallVector<OpFoldResult> dimSizes;
- RankedTensorType tensorType = rankedTensor.getType().cast<RankedTensorType>();
+ RankedTensorType tensorType = cast<RankedTensorType>(rankedTensor.getType());
  for (unsigned i = 0; i < tensorType.getRank(); i++)
    dimSizes.push_back(getShapeDimSize(b, loc, rankedTensor, i));
  return dimSizes;
}
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -44,7 +44,7 @@
SmallVector<Value> mlir::tensor::createDynamicDimValues(OpBuilder &b,
                                                        Location loc,
                                                        Value rankedTensor) {
- auto tensorTy = rankedTensor.getType().cast<RankedTensorType>();
+ auto tensorTy = cast<RankedTensorType>(rankedTensor.getType());
  SmallVector<Value> dynamicDims;
  for (const auto &en : llvm::enumerate(tensorTy.getShape())) {
    if (en.value() == ShapedType::kDynamic)
@@ -57,7 +57,7 @@
FailureOr<OpFoldResult> mlir::tensor::createDimValue(OpBuilder &b, Location loc,
                                                     Value rankedTensor,
                                                     int64_t dim) {
- auto tensorTy = rankedTensor.getType().dyn_cast<RankedTensorType>();
+ auto tensorTy = dyn_cast<RankedTensorType>(rankedTensor.getType());
  if (!tensorTy)
    return failure();
  auto shape = tensorTy.getShape();
@@ -70,7 +70,7 @@
SmallVector<OpFoldResult> mlir::tensor::createDimValues(OpBuilder &b,
                                                        Location loc,
                                                        Value rankedTensor) {
- auto tensorTy = rankedTensor.getType().cast<RankedTensorType>();
+ auto tensorTy = cast<RankedTensorType>(rankedTensor.getType());
  SmallVector<OpFoldResult> dims;
  for (const auto &en : llvm::enumerate(tensorTy.getShape())) {
    if (ShapedType::isDynamic(en.value())) {
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
@@ -34,9 +34,9 @@
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();
    Value weight = op.getWeight();
-   ShapedType inputType = input.getType().cast<ShapedType>();
-   ShapedType weightType = weight.getType().cast<ShapedType>();
-   ShapedType resultType = op.getType().cast<ShapedType>();
+   ShapedType inputType = cast<ShapedType>(input.getType());
+   ShapedType weightType = cast<ShapedType>(weight.getType());
+   ShapedType resultType = cast<ShapedType>(op.getType());

    auto numDynamic =
        llvm::count_if(inputType.getShape(), ShapedType::isDynamic);
@@ -66,7 +66,7 @@
      auto quantizationInfo = op.getQuantizationInfo();
      int64_t iZp = quantizationInfo->getInputZp();

-     if (!validIntegerRange(inputETy.cast<IntegerType>(), iZp))
+     if (!validIntegerRange(cast<IntegerType>(inputETy), iZp))
        return rewriter.notifyMatchFailure(
            op, "tosa.conv op quantization has zp outside of input range");
@@ -116,7 +116,7 @@
        weightShape[3]};
    auto revisedWeightShapeType = RankedTensorType::get(
        revisedWeightShape,
-       weight.getType().dyn_cast<RankedTensorType>().getElementType());
+       dyn_cast<RankedTensorType>(weight.getType()).getElementType());
    auto reshapedWeight = rewriter
                              .create<tosa::ReshapeOp>(
                                  op.getLoc(), revisedWeightShapeType, weight,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -28,9 +28,9 @@
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();
    Value weight = op.getWeight();
-   ShapedType inputType = input.getType().cast<ShapedType>();
-   ShapedType weightType = weight.getType().cast<ShapedType>();
-   ShapedType resultType = op.getOutput().getType().cast<ShapedType>();
+   ShapedType inputType = cast<ShapedType>(input.getType());
+   ShapedType weightType = cast<ShapedType>(weight.getType());
+   ShapedType resultType = cast<ShapedType>(op.getOutput().getType());

    if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
          resultType.hasStaticShape())) {
@@ -52,7 +52,7 @@
        inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
    inputType = RankedTensorType::get(
        revisedInputShape,
-       input.getType().dyn_cast<RankedTensorType>().getElementType());
+       dyn_cast<RankedTensorType>(input.getType()).getElementType());
    input = rewriter
                .create<tosa::ReshapeOp>(
                    op.getLoc(), inputType, input,
@@ -76,7 +76,7 @@
    auto applyZp = [&](Value val, int64_t zp) -> Value {
      if (zp == 0)
        return val;
-     auto ety = val.getType().cast<ShapedType>().getElementType();
+     auto ety = cast<ShapedType>(val.getType()).getElementType();
      auto zpTy = RankedTensorType::get({}, ety);
      auto zpAttr =
          DenseElementsAttr::get(zpTy, rewriter.getIntegerAttr(ety, zp));
@@ -126,17 +126,17 @@
        inputType.getDimSize(2), inputType.getDimSize(3), weightShape[3]};
    auto mulShapeType = RankedTensorType::get(
        mulShape,
-       weight.getType().dyn_cast<RankedTensorType>().getElementType());
+       dyn_cast<RankedTensorType>(weight.getType()).getElementType());
    Value mulValue = rewriter
                         .create<tosa::MulOp>(op.getLoc(), mulShapeType, input,
                                              weight, /*shift=*/0)
                         .getResult();

    // Reshape output to [N, H, W, C * M].
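    // Shape walk-through for this decomposition (illustrative sizes, not from
    // the patch): input tensor<1x5x5x2xf32> is reshaped to 1x5x5x2x1 above,
    // multiplied against a 1x1 depthwise weight viewed as 2x3 to give
    // 1x5x5x2x3, then reshaped below to tensor<1x5x5x6xf32>, i.e.
    // [N, H, W, C * M] with C = 2 and M = 3.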
-   auto outputShape = op.getOutput().getType().cast<ShapedType>().getShape();
+   auto outputShape = cast<ShapedType>(op.getOutput().getType()).getShape();
    auto outputShapeType = RankedTensorType::get(
        outputShape,
-       input.getType().dyn_cast<RankedTensorType>().getElementType());
+       dyn_cast<RankedTensorType>(input.getType()).getElementType());
    auto outputValue = rewriter.create<tosa::ReshapeOp>(
        op.getLoc(), outputShapeType, mulValue,
        rewriter.getDenseI64ArrayAttr(outputShape));
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -56,7 +56,7 @@
  // Compute the knowledge based on the inferred type.
  auto inferredKnowledge =
      mlir::tosa::ValueKnowledge::getPessimisticValueState();
- inferredKnowledge.dtype = resultTy.cast<ShapedType>().getElementType();
+ inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
  inferredKnowledge.hasRank = predictedShape.hasRank();
  if (predictedShape.hasRank()) {
    for (auto dim : predictedShape.getDims()) {
@@ -83,10 +83,10 @@
  Value weight = op->getOperand(1);
  Value bias = op->getOperand(2);

- ShapedType inputTy = input.getType().cast<ShapedType>();
- ShapedType weightTy = weight.getType().cast<ShapedType>();
- ShapedType biasTy = bias.getType().cast<ShapedType>();
- ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
+ ShapedType inputTy = cast<ShapedType>(input.getType());
+ ShapedType weightTy = cast<ShapedType>(weight.getType());
+ ShapedType biasTy = cast<ShapedType>(bias.getType());
+ ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());

  llvm::ArrayRef<int64_t> stride = op.getStride();
  llvm::ArrayRef<int64_t> pad = op.getOutPad();
@@ -146,10 +146,10 @@
  Value weight = op->getOperand(1);
  Value bias = op->getOperand(2);

- ShapedType inputTy = input.getType().cast<ShapedType>();
- ShapedType weightTy = weight.getType().cast<ShapedType>();
- ShapedType biasTy = bias.getType().cast<ShapedType>();
- ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
+ ShapedType inputTy = cast<ShapedType>(input.getType());
+ ShapedType weightTy = cast<ShapedType>(weight.getType());
+ ShapedType biasTy = cast<ShapedType>(bias.getType());
+ ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());

  Type inputETy = inputTy.getElementType();
  Type weightETy = weightTy.getElementType();
@@ -202,7 +202,7 @@
                                          weight, weightPaddingVal);
  }

- weightTy = weight.getType().cast<ShapedType>();
+ weightTy = cast<ShapedType>(weight.getType());
  weightHeight = weightTy.getDimSize(1);
  weightWidth = weightTy.getDimSize(2);
@@ -231,7 +231,7 @@
  weight = createOpAndInfer<tosa::ReshapeOp>(
      rewriter, loc, UnrankedTensorType::get(weightETy), weight,
      rewriter.getDenseI64ArrayAttr(weightReshapeDims1));
- ShapedType restridedWeightTy = weight.getType().cast<ShapedType>();
+ ShapedType restridedWeightTy = cast<ShapedType>(weight.getType());

  weight = createOpAndInfer(
      rewriter, loc, UnrankedTensorType::get(weightETy), weight,
@@ -297,7 +297,7 @@
  }

  // Factor the resulting width / height.
- ShapedType convTy = conv2d.getType().cast<ShapedType>();
+ ShapedType convTy = cast<ShapedType>(conv2d.getType());
  Type convETy = convTy.getElementType();

  int64_t convHeight = convTy.getDimSize(1);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
@@ -72,7 +72,7 @@
  auto baseType = inputType.getElementType();

  // Handle possible integer types
- if (auto intType = baseType.dyn_cast<IntegerType>()) {
+ if (auto intType = dyn_cast<IntegerType>(baseType)) {
    switch (intType.getWidth()) {
    case 1:
      return transposeType<bool>(attr, inputType, outputType, permValues);
@@ -102,7 +102,7 @@
  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const override {
-   auto outputType = op.getType().cast<RankedTensorType>();
+   auto outputType = cast<RankedTensorType>(op.getType());
    // TOSA supports quantized types.
    if (!outputType.getElementType().isIntOrIndexOrFloat())
      return failure();
@@ -122,7 +122,7 @@
        permAttr.getValues<APInt>(),
        [](const APInt &val) { return val.getSExtValue(); }));

-   auto inputType = op.getInput1().getType().cast<RankedTensorType>();
+   auto inputType = cast<RankedTensorType>(op.getInput1().getType());
    auto resultAttr = transpose(inputValues, inputType, outputType, permValues);
    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -54,7 +54,7 @@
  for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) {
    auto inferredTy = shapesStorage[op.getOperand(i)];
    auto blockArg = frontBlock.getArgument(i - 1);
-   auto oldType = blockArg.getType().cast<ShapedType>();
+   auto oldType = cast<ShapedType>(blockArg.getType());

    if (inferredTy.hasRank()) {
      Type newType = oldType.clone(inferredTy.getDims());
@@ -89,7 +89,7 @@
  // loop body / condition for tosa.while.
  llvm::SmallVector<Type> argTypes;
  for (auto operand : op.getOperands()) {
-   auto operandTy = operand.getType().cast<ShapedType>();
+   auto operandTy = cast<ShapedType>(operand.getType());
    auto shapedTypeComponent = shapesStorage[operand];
    if (shapedTypeComponent.hasRank()) {
      auto newTy = operandTy.clone(shapedTypeComponent.getDims());
@@ -188,7 +188,7 @@
void propagateShapesInRegion(Region &region) {
  DenseMap<Value, ShapedTypeComponents> shapesStorage;
  auto setShapes = [&](Value val, Type t) {
-   if (auto st = t.dyn_cast<ShapedType>())
+   if (auto st = dyn_cast<ShapedType>(t))
      shapesStorage[val] = st;
    else
      shapesStorage[val] = t;
  };
@@ -247,8 +247,7 @@
      // Compute the knowledge based on the inferred type.
      auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
-     inferredKnowledge.dtype =
-         resultTy.cast<ShapedType>().getElementType();
+     inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
      inferredKnowledge.hasRank = predictedShape.hasRank();
      if (predictedShape.hasRank()) {
        for (auto dim : predictedShape.getDims()) {
@@ -274,7 +273,7 @@
  for (auto it : shapesStorage) {
    auto result = it.second;
    if (result.hasRank()) {
-     Type t = it.first.getType().cast<ShapedType>().clone(result.getDims());
+     Type t = cast<ShapedType>(it.first.getType()).clone(result.getDims());
      it.first.setType(t);
    }
  }
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
@@ -82,8 +82,8 @@
                                   Location loc,
                                   RankedTensorType outputType, Value &input1,
                                   Value &input2) {
- auto input1Ty = input1.getType().dyn_cast<RankedTensorType>();
- auto input2Ty = input2.getType().dyn_cast<RankedTensorType>();
+ auto input1Ty = dyn_cast<RankedTensorType>(input1.getType());
+ auto input2Ty = dyn_cast<RankedTensorType>(input2.getType());

  if (!input1Ty || !input2Ty) {
    return rewriter.notifyMatchFailure(loc, "input not a ranked tensor");
@@ -106,9 +106,9 @@
  }

  ArrayRef<int64_t> higherRankShape =
-     higherTensorValue.getType().cast<RankedTensorType>().getShape();
+     cast<RankedTensorType>(higherTensorValue.getType()).getShape();
  ArrayRef<int64_t> lowerRankShape =
-     lowerTensorValue.getType().cast<RankedTensorType>().getShape();
+     cast<RankedTensorType>(lowerTensorValue.getType()).getShape();

  SmallVector<int64_t> reshapeOutputShape;

@@ -116,7 +116,7 @@
          .failed())
    return rewriter.notifyMatchFailure(loc, "fail to compute a reshape type");

- auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>();
+ auto reshapeInputType = cast<RankedTensorType>(lowerTensorValue.getType());
  auto reshapeOutputType = RankedTensorType::get(
      ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
@@ -155,7 +155,7 @@
    Value input2 = tosaBinaryOp.getInput2();
    Value output = tosaBinaryOp.getResult();

-   auto outputType = output.getType().dyn_cast<RankedTensorType>();
+   auto outputType = dyn_cast<RankedTensorType>(output.getType());
    if (!outputType)
      return failure();
@@ -183,7 +183,7 @@
    Value input2 = tosaBinaryOp.getInput2();
    int32_t shift = tosaBinaryOp.getShift();
    Value output = tosaBinaryOp.getResult();
-   auto outputType = output.getType().dyn_cast<RankedTensorType>();
+   auto outputType = dyn_cast<RankedTensorType>(output.getType());
    if (!outputType)
      return failure();
@@ -214,7 +214,7 @@
    Value input2 = tosaBinaryOp.getInput2();
    int32_t round = tosaBinaryOp.getRound();
    Value output = tosaBinaryOp.getResult();
-   auto outputType = output.getType().dyn_cast<RankedTensorType>();
+   auto outputType = dyn_cast<RankedTensorType>(output.getType());
    if (!outputType)
      return failure();
@@ -242,7 +242,7 @@
    Value input3 = tosaOp.getOnFalse();
    Value output = tosaOp.getResult();

-   auto outputType = output.getType().dyn_cast<RankedTensorType>();
+   auto outputType = dyn_cast<RankedTensorType>(output.getType());
    if (!outputType)
      return rewriter.notifyMatchFailure(tosaOp, "output not a ranked tensor");
@@ -265,9 +265,9 @@
        tosaOp,
        "cannot rewrite as the rank of all operands is already aligned");

- int32_t result1Rank = input1.getType().cast<RankedTensorType>().getRank();
- int32_t result2Rank = input2.getType().cast<RankedTensorType>().getRank();
- int32_t result3Rank = input3.getType().cast<RankedTensorType>().getRank();
+ int32_t result1Rank = cast<RankedTensorType>(input1.getType()).getRank();
+ int32_t result2Rank = cast<RankedTensorType>(input2.getType()).getRank();
+ int32_t result3Rank = cast<RankedTensorType>(input3.getType()).getRank();

  if ((result1Rank != result2Rank) || (result2Rank != result3Rank))
    return rewriter.notifyMatchFailure(
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -106,7 +106,7 @@ getOperation().walk([&](Operation *op) { for (Value operand : op->getOperands()) { if ((profileType == TosaProfileEnum::BaseInference) && - getElementTypeOrSelf(operand).isa<FloatType>()) { + isa<FloatType>(getElementTypeOrSelf(operand))) { return signalPassFailure(); } if (getElementTypeOrSelf(operand).isF64()) { diff --git a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp --- a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp +++ b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp @@ -116,16 +116,16 @@ mlir::tosa::buildConvOpQuantizationAttr(OpBuilder &builder, Value input, Value weight) { - auto inputType = input.getType().dyn_cast<ShapedType>(); - auto weightType = weight.getType().dyn_cast<ShapedType>(); + auto inputType = dyn_cast<ShapedType>(input.getType()); + auto weightType = dyn_cast<ShapedType>(weight.getType()); if (!inputType || !weightType) return nullptr; auto inputQType = GET_UQTYPE(inputType); auto weightPerTensorQType = GET_UQTYPE(weightType); - auto weightPerAxisQType = weightType.getElementType() - .dyn_cast<quant::UniformQuantizedPerAxisType>(); + auto weightPerAxisQType = + dyn_cast<quant::UniformQuantizedPerAxisType>(weightType.getElementType()); // Weights must be either per-tensor quantized or per-axis quantized. assert(!((bool)weightPerTensorQType && (bool)weightPerAxisQType) && @@ -160,8 +160,8 @@ mlir::tosa::buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, Value b) { - auto aType = a.getType().dyn_cast<ShapedType>(); - auto bType = b.getType().dyn_cast<ShapedType>(); + auto aType = dyn_cast<ShapedType>(a.getType()); + auto bType = dyn_cast<ShapedType>(b.getType()); if (!aType || !bType) return nullptr; @@ -189,8 +189,8 @@ mlir::tosa::buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, Type outputRawType) { - auto inputType = input.getType().dyn_cast<ShapedType>(); - auto outputType = outputRawType.dyn_cast<ShapedType>(); + auto inputType = dyn_cast<ShapedType>(input.getType()); + auto outputType = dyn_cast<ShapedType>(outputRawType); if (!inputType || !outputType) return nullptr; @@ -215,7 +215,7 @@ PadOpQuantizationAttr mlir::tosa::buildPadOpQuantizationAttr(OpBuilder &builder, Value input) { - auto inputType = input.getType().dyn_cast<ShapedType>(); + auto inputType = dyn_cast<ShapedType>(input.getType()); if (!inputType) return nullptr; @@ -235,8 +235,8 @@ Type mlir::tosa::buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input, Value weight) { - auto inputType = input.getType().dyn_cast<ShapedType>(); - auto weightType = weight.getType().dyn_cast<ShapedType>(); + auto inputType = dyn_cast<ShapedType>(input.getType()); + auto weightType = dyn_cast<ShapedType>(weight.getType()); assert(inputType && weightType && "Could not extract input or weight tensors from Conv op"); @@ -250,7 +250,7 @@ unsigned inputBits = inputQType.getStorageTypeIntegralWidth(); unsigned weightBits = weightQType.getStorageTypeIntegralWidth(); - auto outputShapedType = outputType.dyn_cast<ShapedType>(); + auto outputShapedType = dyn_cast<ShapedType>(outputType); assert(outputShapedType && "Could not extract output shape type from Conv op"); @@ -274,8 +274,8 @@ auto convfunc = quant::ExpressedToQuantizedConverter::forInputType(inputDType); - auto minElems = minAttr.dyn_cast<DenseFPElementsAttr>(); - auto maxElems = maxAttr.dyn_cast<DenseFPElementsAttr>(); + auto minElems = dyn_cast<DenseFPElementsAttr>(minAttr); + auto maxElems = dyn_cast<DenseFPElementsAttr>(maxAttr); SmallVector<double> min, max; @@ -291,12 +291,12 @@ for (auto i : maxElems) max.push_back(FloatAttr::getValueAsDouble(i)); } else { // Just a single FP value.
- auto minVal = minAttr.dyn_cast<FloatAttr>(); + auto minVal = dyn_cast<FloatAttr>(minAttr); if (minVal) min.push_back(minVal.getValueAsDouble()); else return {}; - auto maxVal = maxAttr.dyn_cast<FloatAttr>(); + auto maxVal = dyn_cast<FloatAttr>(maxAttr); if (maxVal) max.push_back(maxVal.getValueAsDouble()); else @@ -309,7 +309,7 @@ builder.getUnknownLoc(), quantBits.getInt(), min[0], max[0], narrowRange.getValue(), convfunc.expressedType, isSigned); } else if (min.size() > 1) { // Per-axis quant on filterQuantDim. - auto shape = inputDType.dyn_cast<ShapedType>(); + auto shape = dyn_cast<ShapedType>(inputDType); if (!shape) return {}; if ((filterQuantDim) >= 0 && (shape.getRank() > filterQuantDim)) { diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp --- a/mlir/lib/Dialect/Traits.cpp +++ b/mlir/lib/Dialect/Traits.cpp @@ -116,7 +116,7 @@ /// Returns the shape of the given type. Scalars will be considered as having a /// shape with zero dimensions. static ArrayRef<int64_t> getShape(Type type) { - if (auto sType = type.dyn_cast<ShapedType>()) + if (auto sType = dyn_cast<ShapedType>(type)) return sType.getShape(); return {}; } @@ -142,8 +142,8 @@ // If one of the types is unranked tensor, then the other type shouldn't be // vector and the result should have unranked tensor type. - if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>()) { - if (type1.isa<VectorType>() || type2.isa<VectorType>()) + if (isa<UnrankedTensorType>(type1) || isa<UnrankedTensorType>(type2)) { + if (isa<VectorType>(type1) || isa<VectorType>(type2)) return {}; return UnrankedTensorType::get(elementType); } @@ -151,7 +151,7 @@ // Returns the type kind if the given type is a vector or ranked tensor type. // Returns std::nullopt otherwise. auto getCompositeTypeKind = [](Type type) -> std::optional<TypeID> { - if (type.isa<VectorType, RankedTensorType>()) + if (isa<VectorType, RankedTensorType>(type)) return type.getTypeID(); return std::nullopt; }; @@ -189,8 +189,8 @@ template <typename iterator_range> static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) { return std::make_tuple( - llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); }), - llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); })); + llvm::any_of(types, [](Type t) { return isa<TensorType>(t); }), + llvm::any_of(types, [](Type t) { return isa<VectorType>(t); })); } static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred, @@ -242,7 +242,7 @@ return op->emitError("cannot broadcast vector with tensor"); auto rankedOperands = make_filter_range( - op->getOperandTypes(), [](Type t) { return t.isa<RankedTensorType>(); }); + op->getOperandTypes(), [](Type t) { return isa<RankedTensorType>(t); }); // If all operands are unranked, then all result shapes are possible. if (rankedOperands.empty()) @@ -261,7 +261,7 @@ } auto rankedResults = make_filter_range( - op->getResultTypes(), [](Type t) { return t.isa<RankedTensorType>(); }); + op->getResultTypes(), [](Type t) { return isa<RankedTensorType>(t); }); // If all of the results are unranked then no further verification. if (rankedResults.empty()) diff --git a/mlir/lib/Dialect/Transform/Transforms/CheckUses.cpp b/mlir/lib/Dialect/Transform/Transforms/CheckUses.cpp --- a/mlir/lib/Dialect/Transform/Transforms/CheckUses.cpp +++ b/mlir/lib/Dialect/Transform/Transforms/CheckUses.cpp @@ -148,14 +148,14 @@ // TODO: when this ported to the dataflow analysis infra, we should have // proper support for region-based control flow. Operation *valueSource = - operand.get().isa<OpResult>() + isa<OpResult>(operand.get()) ?
operand.get().getDefiningOp() : operand.get().getParentBlock()->getParentOp(); auto iface = cast<MemoryEffectOpInterface>(valueSource); SmallVector<MemoryEffects::EffectInstance> instances; iface.getEffectsOnResource(transform::TransformMappingResource::get(), instances); - assert((operand.get().isa<BlockArgument>() || + assert((isa<BlockArgument>(operand.get()) || hasEffect<MemoryEffects::Allocate>(instances, operand.get())) && "expected the op defining the value to have an allocation effect " "on it"); @@ -182,7 +182,7 @@ // value is defined in the middle of the block, i.e., is not a block // argument. bool isOutermost = ancestor == ancestors.front(); - bool isFromBlockPartial = isOutermost && operand.get().isa<BlockArgument>(); + bool isFromBlockPartial = isOutermost && isa<BlockArgument>(operand.get()); // Check if the value may be freed by operations between its definition // (allocation) point in its block and the terminator of the block or the diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -162,7 +162,7 @@ SmallVector<Attribute, 4> reassociationAttr = llvm::to_vector<4>(llvm::map_range( reassociation, [&](const ReassociationIndices &indices) -> Attribute { - return b.getI64ArrayAttr(indices).cast<Attribute>(); + return cast<Attribute>(b.getI64ArrayAttr(indices)); })); return b.getArrayAttr(reassociationAttr); } @@ -267,7 +267,7 @@ } bool mlir::hasNonIdentityLayout(Type type) { - if (auto memrefType = type.dyn_cast<MemRefType>()) + if (auto memrefType = dyn_cast<MemRefType>(type)) return !memrefType.getLayout().isIdentity(); return false; } diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -19,7 +19,7 @@ if (!v) return false; if (auto attr = v.dyn_cast<Attribute>()) { - IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>(); + IntegerAttr intAttr = dyn_cast<IntegerAttr>(attr); return intAttr && intAttr.getValue().isZero(); } if (auto cst = v.get<Value>().getDefiningOp<arith::ConstantIndexOp>()) @@ -53,7 +53,7 @@ SmallVectorImpl<int64_t> &staticVec) { auto v = ofr.dyn_cast<Value>(); if (!v) { - APInt apInt = ofr.get<Attribute>().cast<IntegerAttr>().getValue(); + APInt apInt = cast<IntegerAttr>(ofr.get<Attribute>()).getValue(); staticVec.push_back(apInt.getSExtValue()); return; } @@ -71,8 +71,8 @@ /// Extract int64_t values from the assumed ArrayAttr of IntegerAttr. SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) { return llvm::to_vector<4>( - llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t { - return a.cast<IntegerAttr>().getInt(); + llvm::map_range(cast<ArrayAttr>(attr), [](Attribute a) -> int64_t { + return cast<IntegerAttr>(a).getInt(); })); } @@ -124,7 +124,7 @@ } // Case 2: Check for IntegerAttr.
Attribute attr = ofr.dyn_cast<Attribute>(); - if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>()) + if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr)) return intAttr.getValue().getSExtValue(); return std::nullopt; } @@ -184,7 +184,7 @@ SmallVector<Value> dynamicValues; for (const auto &it : mixedValues) { if (it.is<Attribute>()) { - staticValues.push_back(it.get<Attribute>().cast<IntegerAttr>().getInt()); + staticValues.push_back(cast<IntegerAttr>(it.get<Attribute>()).getInt()); } else { staticValues.push_back(ShapedType::kDynamic); dynamicValues.push_back(it.get<Value>()); diff --git a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp --- a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp @@ -21,9 +21,9 @@ if (indexingMaps.size() != 3) return false; - auto map0 = indexingMaps[0].cast<AffineMapAttr>().getValue(); - auto map1 = indexingMaps[1].cast<AffineMapAttr>().getValue(); - auto map2 = indexingMaps[2].cast<AffineMapAttr>().getValue(); + auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue(); + auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue(); + auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue(); if (map0.getNumResults() != 2 || map1.getNumResults() != 2 || map2.getNumResults() != 2 || map0.getNumInputs() != 3 || @@ -47,9 +47,9 @@ if (indexingMaps.size() != 3) return false; - auto map0 = indexingMaps[0].cast<AffineMapAttr>().getValue(); - auto map1 = indexingMaps[1].cast<AffineMapAttr>().getValue(); - auto map2 = indexingMaps[2].cast<AffineMapAttr>().getValue(); + auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue(); + auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue(); + auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue(); if (map0.getNumResults() != 2 || map1.getNumResults() != 2 || map2.getNumResults() != 2 || map0.getNumInputs() != 3 || @@ -73,9 +73,9 @@ if (indexingMaps.size() != 3) return false; - auto map0 = indexingMaps[0].cast<AffineMapAttr>().getValue(); - auto map1 = indexingMaps[1].cast<AffineMapAttr>().getValue(); - auto map2 = indexingMaps[2].cast<AffineMapAttr>().getValue(); + auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue(); + auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue(); + auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue(); if (map0.getNumResults() != 3 || map1.getNumResults() != 3 || map2.getNumResults() != 3 || map0.getNumInputs() != 4 || diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp @@ -30,14 +30,14 @@ vector::TransferReadOp> { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - assert(opOperand.get().getType().isa<RankedTensorType>() && + assert(isa<RankedTensorType>(opOperand.get().getType()) && "only tensor types expected"); return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - assert(opOperand.get().getType().isa<RankedTensorType>() && + assert(isa<RankedTensorType>(opOperand.get().getType()) && "only tensor types expected"); return false; } @@ -50,7 +50,7 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto readOp = cast<vector::TransferReadOp>(op); - assert(readOp.getShapedType().isa<TensorType>() && + assert(isa<TensorType>(readOp.getShapedType()) && "only tensor types expected"); FailureOr<Value> buffer = getBuffer(rewriter, readOp.getSource(), options); if (failed(buffer)) @@ -74,7 +74,7 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto writeOp = cast<vector::TransferWriteOp>(op); - assert(writeOp.getShapedType().isa<TensorType>() && + assert(isa<TensorType>(writeOp.getShapedType()) && "only tensor types expected"); // Create a new
transfer_write on buffer that doesn't have a return value. @@ -99,14 +99,14 @@ vector::GatherOp> { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - assert(opOperand.get().getType().isa<RankedTensorType>() && + assert(isa<RankedTensorType>(opOperand.get().getType()) && "only tensor types expected"); return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - assert(opOperand.get().getType().isa<RankedTensorType>() && + assert(isa<RankedTensorType>(opOperand.get().getType()) && "only tensor types expected"); return false; } @@ -119,7 +119,7 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto gatherOp = cast<vector::GatherOp>(op); - assert(gatherOp.getBaseType().isa<TensorType>() && + assert(isa<TensorType>(gatherOp.getBaseType()) && "only tensor types expected"); FailureOr<Value> buffer = getBuffer(rewriter, gatherOp.getBase(), options); if (failed(buffer)) @@ -266,7 +266,7 @@ // may get dropped during the bufferization of vector.mask. SmallVector<Value> newResults; for (Value value : yieldOp.getOperands()) { - if (value.getType().isa<TensorType>()) { + if (isa<TensorType>(value.getType())) { FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options); if (failed(maybeBuffer)) return failure(); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp @@ -49,7 +49,7 @@ PatternRewriter &rewriter) const override { auto loc = op.getLoc(); VectorType dstType = op.getResultVectorType(); - VectorType srcType = op.getSourceType().dyn_cast<VectorType>(); + VectorType srcType = dyn_cast<VectorType>(op.getSourceType()); Type eltType = dstType.getElementType(); // Scalar to any vector can use splat. diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp @@ -96,9 +96,9 @@ return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr); } // Unroll leading dimensions. - VectorType vType = lowType.cast<VectorType>(); + VectorType vType = cast<VectorType>(lowType); Type resType = VectorType::Builder(type).dropDim(index); - auto resVectorType = resType.cast<VectorType>(); + auto resVectorType = cast<VectorType>(resType); Value result = rewriter.create<arith::ConstantOp>( loc, resVectorType, rewriter.getZeroAttr(resVectorType)); for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) { @@ -126,7 +126,7 @@ } // Unroll leading dimensions. Type lowType = VectorType::Builder(type).dropDim(0); - VectorType vType = lowType.cast<VectorType>(); + VectorType vType = cast<VectorType>(lowType); Type insType = VectorType::Builder(vType).dropDim(0); for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) { auto posAttr = rewriter.getI64ArrayAttr(d); @@ -160,7 +160,7 @@ // Only valid for integer types. return std::nullopt; // Special case for fused multiply-add. - if (acc && acc.getType().isa<VectorType>() && kind == CombiningKind::ADD) { + if (acc && isa<VectorType>(acc.getType()) && kind == CombiningKind::ADD) { Value fma = rewriter.create<vector::FMAOp>(loc, x, y, acc); if (mask) // The fma op doesn't need explicit masking. However, fma ops used in
However, fma ops used in @@ -418,7 +418,7 @@ Value promote(Value v, Type dstElementType) { Type elementType = v.getType(); - auto vecType = elementType.dyn_cast(); + auto vecType = dyn_cast(elementType); if (vecType) elementType = vecType.getElementType(); if (elementType == dstElementType) @@ -426,7 +426,7 @@ Type promotedType = dstElementType; if (vecType) promotedType = VectorType::get(vecType.getShape(), promotedType); - if (dstElementType.isa()) + if (isa(dstElementType)) return rewriter.create(loc, promotedType, v); return rewriter.create(loc, promotedType, v); } @@ -438,7 +438,7 @@ if (mask && !maybeMask.has_value()) return failure(); - Type resElementType = res.getType().cast().getElementType(); + Type resElementType = cast(res.getType()).getElementType(); for (int64_t k = 0; k < reductionSize; ++k) { Value extractA = rewriter.create(loc, lhs, k); Value extractB = rewriter.create(loc, rhs, k); @@ -684,7 +684,7 @@ return failure(); } - VectorType dstType = op.getResultType().cast(); + VectorType dstType = cast(op.getResultType()); assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 && "Expected dst type of rank 1 or 2"); @@ -695,7 +695,7 @@ // ExtractOp does not allow dynamic indexing, we must unroll explicitly. Value res = rewriter.create(loc, dstType, rewriter.getZeroAttr(dstType)); - bool isInt = dstType.getElementType().isa(); + bool isInt = isa(dstType.getElementType()); for (unsigned r = 0; r < dstRows; ++r) { Value a = rewriter.create(op.getLoc(), lhs, r); for (unsigned c = 0; c < dstColumns; ++c) { @@ -789,7 +789,7 @@ } else { // If the parallel dimension doesn't exist we will have to broadcast it. lhsDims.push_back( - contractOp.getResultType().cast().getDimSize(i)); + cast(contractOp.getResultType()).getDimSize(i)); lhsTranspose.push_back(lhsDims.size() - 1); } std::optional rhsDim = @@ -799,7 +799,7 @@ } else { // If the parallel dimension doesn't exist we will have to broadcast it. rhsDims.push_back( - contractOp.getResultType().cast().getDimSize(i)); + cast(contractOp.getResultType()).getDimSize(i)); rhsTranspose.push_back(rhsDims.size() - 1); } } @@ -969,7 +969,7 @@ Value mask) const { VectorType lhsType = op.getLhsType(); VectorType rhsType = op.getRhsType(); - VectorType resType = op.getResultType().cast(); + VectorType resType = cast(op.getResultType()); // Find the iterator type index and result index. SmallVector iMap = op.getIndexingMapsArray(); int64_t iterIndex = -1; @@ -1044,10 +1044,10 @@ VectorType lhsType = op.getLhsType(); VectorType rhsType = op.getRhsType(); Type resType = op.getResultType(); - if (resType.isa()) + if (isa(resType)) return rewriter.notifyMatchFailure(op, "did not expect a VectorType result"); - bool isInt = resType.isa(); + bool isInt = isa(resType); // Use iterator index 0. int64_t iterIndex = 0; SmallVector iMap = op.getIndexingMapsArray(); @@ -1133,10 +1133,10 @@ auto loc = op.getLoc(); VectorType lhsType = op.getOperandVectorTypeLHS(); - VectorType rhsType = op.getOperandTypeRHS().dyn_cast(); + VectorType rhsType = dyn_cast(op.getOperandTypeRHS()); VectorType resType = op.getResultVectorType(); Type eltType = resType.getElementType(); - bool isInt = eltType.isa(); + bool isInt = isa(eltType); Value acc = (op.getAcc().empty()) ? 
nullptr : op.getAcc()[0]; vector::CombiningKind kind = op.getKind(); @@ -1231,7 +1231,7 @@ return failure(); Type dstElementType = op.getType(); - if (auto vecType = dstElementType.dyn_cast<VectorType>()) + if (auto vecType = dyn_cast<VectorType>(dstElementType)) dstElementType = vecType.getElementType(); if (elementType != dstElementType) return failure(); @@ -1259,8 +1259,8 @@ return failure(); // At this point lhs and rhs are in row-major. - VectorType lhsType = lhs.getType().cast<VectorType>(); - VectorType rhsType = rhs.getType().cast<VectorType>(); + VectorType lhsType = cast<VectorType>(lhs.getType()); + VectorType rhsType = cast<VectorType>(rhs.getType()); int64_t lhsRows = lhsType.getDimSize(0); int64_t lhsColumns = lhsType.getDimSize(1); int64_t rhsColumns = rhsType.getDimSize(1); @@ -1289,7 +1289,7 @@ llvm_unreachable("invalid contraction semantics"); Value res = - elementType.isa<IntegerType>() + isa<IntegerType>(elementType) ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul)) : static_cast<Value>( rew.create<arith::AddFOp>(loc, op.getAcc(), mul)); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp @@ -52,7 +52,7 @@ LogicalResult matchAndRewrite(vector::CreateMaskOp op, PatternRewriter &rewriter) const override { - auto dstType = op.getResult().getType().cast<VectorType>(); + auto dstType = cast<VectorType>(op.getResult().getType()); int64_t rank = dstType.getRank(); if (rank <= 1) return rewriter.notifyMatchFailure( @@ -112,7 +112,7 @@ if (rank == 0) { assert(dimSizes.size() == 1 && "Expected exactly one dim size for a 0-D vector"); - bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1; + bool value = cast<IntegerAttr>(dimSizes[0]).getInt() == 1; rewriter.replaceOpWithNewOp<arith::ConstantOp>( op, dstType, DenseIntElementsAttr::get( @@ -122,14 +122,14 @@ } // Scalable constant masks can only be lowered for the "none set" case.
- if (dstType.cast<VectorType>().isScalable()) { + if (cast<VectorType>(dstType).isScalable()) { rewriter.replaceOpWithNewOp<arith::ConstantOp>( op, DenseElementsAttr::get(dstType, false)); return success(); } int64_t trueDim = std::min(dstType.getDimSize(0), - dimSizes[0].cast<IntegerAttr>().getInt()); + cast<IntegerAttr>(dimSizes[0]).getInt()); if (rank == 1) { // Express constant 1-D case in explicit vector form: @@ -146,7 +146,7 @@ VectorType::get(dstType.getShape().drop_front(), eltType); SmallVector<int64_t> newDimSizes; for (int64_t r = 1; r < rank; r++) - newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt()); + newDimSizes.push_back(cast<IntegerAttr>(dimSizes[r]).getInt()); Value trueVal = rewriter.create<vector::ConstantMaskOp>( loc, lowType, rewriter.getI64ArrayAttr(newDimSizes)); Value result = rewriter.create<arith::ConstantOp>( diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp @@ -48,7 +48,7 @@ PatternRewriter &rewriter) { using vector::CombiningKind; - auto elType = x.getType().cast<VectorType>().getElementType(); + auto elType = cast<VectorType>(x.getType()).getElementType(); bool isInt = elType.isIntOrIndex(); Value combinedResult{nullptr}; diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -29,7 +29,7 @@ size_t index = 0; for (unsigned pos : permutation) newInBoundsValues[pos] = - attr.getValue()[index++].cast<BoolAttr>().getValue(); + cast<BoolAttr>(attr.getValue()[index++]).getValue(); return builder.getBoolArrayAttr(newInBoundsValues); } @@ -37,7 +37,7 @@ /// dimensions. static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec, int64_t addedRank) { - auto originalVecType = vec.getType().cast<VectorType>(); + auto originalVecType = cast<VectorType>(vec.getType()); SmallVector<int64_t> newShape(addedRank, 1); newShape.append(originalVecType.getShape().begin(), originalVecType.getShape().end()); @@ -257,7 +257,7 @@ // All the new dimensions added are inbound. SmallVector<bool> newInBoundsValues(missingInnerDim.size(), true); for (Attribute attr : op.getInBounds().value().getValue()) { - newInBoundsValues.push_back(attr.cast<BoolAttr>().getValue()); + newInBoundsValues.push_back(cast<BoolAttr>(attr).getValue()); } newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues); } @@ -315,7 +315,7 @@ // In the meantime, lower these to a scalar load when they pop up. if (reducedShapeRank == 0) { Value newRead; - if (op.getShapedType().isa<MemRefType>()) { + if (isa<MemRefType>(op.getShapedType())) { newRead = rewriter.create<memref::LoadOp>( op.getLoc(), op.getSource(), op.getIndices()); } else { @@ -397,7 +397,7 @@ &broadcastedDims)) return rewriter.notifyMatchFailure(read, "not minor identity + bcast"); - auto memRefType = read.getShapedType().dyn_cast<MemRefType>(); + auto memRefType = dyn_cast<MemRefType>(read.getShapedType()); if (!memRefType) return rewriter.notifyMatchFailure(read, "not a memref source"); @@ -418,11 +418,11 @@ // `vector.load` supports vector types as memref's elements only when the // resulting vector type is the same as the element type. auto memrefElTy = memRefType.getElementType(); - if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType) + if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType) return rewriter.notifyMatchFailure(read, "incompatible element type"); // Otherwise, element types of the memref and the vector must match.
- if (!memrefElTy.isa<VectorType>() && + if (!isa<VectorType>(memrefElTy) && memrefElTy != read.getVectorType().getElementType()) return rewriter.notifyMatchFailure(read, "non-matching element type"); @@ -543,7 +543,7 @@ diag << "permutation map is not minor identity: " << write; }); - auto memRefType = write.getShapedType().dyn_cast<MemRefType>(); + auto memRefType = dyn_cast<MemRefType>(write.getShapedType()); if (!memRefType) return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) { diag << "not a memref type: " << write; @@ -558,13 +558,13 @@ // `vector.store` supports vector types as memref's elements only when the // type of the vector value being written is the same as the element type. auto memrefElTy = memRefType.getElementType(); - if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType()) + if (isa<VectorType>(memrefElTy) && memrefElTy != write.getVectorType()) return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) { diag << "elemental type mismatch: " << write; }); // Otherwise, element types of the memref and the vector must match. - if (!memrefElTy.isa<VectorType>() && + if (!isa<VectorType>(memrefElTy) && memrefElTy != write.getVectorType().getElementType()) return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) { diag << "elemental type mismatch: " << write; diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp @@ -156,7 +156,7 @@ /// dst[511:384] := SELECT4(v2[511:0], mask[7:6]) static Value create4x128BitSuffle(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) { - assert(v1.getType().cast<VectorType>().getShape()[0] == 16 && + assert(cast<VectorType>(v1.getType()).getShape()[0] == 16 && "expected a vector with length=16"); SmallVector<int64_t> shuffleMask; auto appendToMask = [&](int64_t base, uint8_t control) { @@ -291,7 +291,7 @@ vs[0xf] = create4x128BitSuffle(b, t7, tf, 0xdd); auto reshInputType = VectorType::get( - {m, n}, source.getType().cast<VectorType>().getElementType()); + {m, n}, cast<VectorType>(source.getType()).getElementType()); Value res = b.create<arith::ConstantOp>(reshInputType, b.getZeroAttr(reshInputType)); for (int64_t i = 0; i < m; ++i) @@ -329,7 +329,7 @@ // Set up convenience transposition table.
SmallVector<int64_t> transp; for (auto attr : op.getTransp()) - transp.push_back(attr.cast<IntegerAttr>().getInt()); + transp.push_back(cast<IntegerAttr>(attr).getInt()); if (isShuffleLike(vectorTransformOptions.vectorTransposeLowering) && resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) @@ -419,7 +419,7 @@ SmallVector<int64_t> transp; for (auto attr : op.getTransp()) - transp.push_back(attr.cast<IntegerAttr>().getInt()); + transp.push_back(cast<IntegerAttr>(attr).getInt()); if (transp[0] != 1 && transp[1] != 0) return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation"); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -62,8 +62,8 @@ Value laneId, Value zero) : sequentialVal(sequentialVal), distributedVal(distributedVal), laneId(laneId), zero(zero) { - sequentialVectorType = sequentialVal.getType().dyn_cast<VectorType>(); - distributedVectorType = distributedVal.getType().dyn_cast<VectorType>(); + sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType()); + distributedVectorType = dyn_cast<VectorType>(distributedVal.getType()); if (sequentialVectorType && distributedVectorType) distributionMap = calculateImplicitMap(sequentialVectorType, distributedVectorType); @@ -89,7 +89,7 @@ "Must store either the preregistered distributed or the " "preregistered sequential value."); // Scalar case can directly use memref.store. - if (!val.getType().isa<VectorType>()) + if (!isa<VectorType>(val.getType())) return b.create<memref::StoreOp>(loc, val, buffer, zero); // Vector case must use vector::TransferWriteOp which will later lower to @@ -131,7 +131,7 @@ Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) { // Scalar case can directly use memref.store. - if (!type.isa<VectorType>()) + if (!isa<VectorType>(type)) return b.create<memref::LoadOp>(loc, buffer, zero); // Other cases must be vector atm. @@ -149,7 +149,7 @@ } SmallVector<bool> inBounds(indices.size(), true); return b.create<vector::TransferReadOp>( - loc, type.cast<VectorType>(), buffer, indices, + loc, cast<VectorType>(type), buffer, indices, ArrayRef<bool>(inBounds.begin(), inBounds.end())); } @@ -630,14 +630,14 @@ Location loc = warpOp.getLoc(); for (OpOperand &operand : elementWise->getOpOperands()) { Type targetType; - if (auto vecType = distributedVal.getType().dyn_cast<VectorType>()) { + if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) { // If the result type is a vector, the operands must also be vectors.
- auto operandType = operand.get().getType().cast<VectorType>(); + auto operandType = cast<VectorType>(operand.get().getType()); targetType = VectorType::get(vecType.getShape(), operandType.getElementType()); } else { auto operandType = operand.get().getType(); - assert(!operandType.isa<VectorType>() && + assert(!isa<VectorType>(operandType) && "unexpected yield of vector from op with scalar result type"); targetType = operandType; } @@ -687,7 +687,7 @@ if (!yieldOperand) return failure(); auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>(); - auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>(); + auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue()); if (!dense) return failure(); unsigned operandIndex = yieldOperand->getOperandNumber(); @@ -737,8 +737,8 @@ SmallVector<Value> indices(read.getIndices().begin(), read.getIndices().end()); - auto sequentialType = read.getResult().getType().cast<VectorType>(); - auto distributedType = distributedVal.getType().cast<VectorType>(); + auto sequentialType = cast<VectorType>(read.getResult().getType()); + auto distributedType = cast<VectorType>(distributedVal.getType()); AffineMap map = calculateImplicitMap(sequentialType, distributedType); AffineMap indexMap = map.compose(read.getPermutationMap()); OpBuilder::InsertionGuard g(rewriter); @@ -752,7 +752,7 @@ unsigned indexPos = indexExpr.getPosition(); unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition(); int64_t scale = - distributedVal.getType().cast<VectorType>().getDimSize(vectorPos); + cast<VectorType>(distributedVal.getType()).getDimSize(vectorPos); indices[indexPos] = affine::makeComposedAffineApply( rewriter, read.getLoc(), d0 + scale * d1, {indices[indexPos], warpOp.getLaneid()}); @@ -845,7 +845,7 @@ resultIndex = operand.getOperandNumber(); break; } - auto arg = operand.get().dyn_cast<BlockArgument>(); + auto arg = dyn_cast<BlockArgument>(operand.get()); if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation()) continue; Value warpOperand = warpOp.getArgs()[arg.getArgNumber()]; @@ -874,7 +874,7 @@ auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>(); Location loc = broadcastOp.getLoc(); auto destVecType = - warpOp->getResultTypes()[operandNumber].cast<VectorType>(); + cast<VectorType>(warpOp->getResultTypes()[operandNumber]); SmallVector<size_t> newRetIndices; WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( rewriter, warpOp, {broadcastOp.getSource()}, @@ -914,7 +914,7 @@ // Rewrite vector.extract with 1d source to vector.extractelement. if (extractSrcType.getRank() == 1) { assert(extractOp.getPosition().size() == 1 && "expected 1 index"); - int64_t pos = extractOp.getPosition()[0].cast<IntegerAttr>().getInt(); + int64_t pos = cast<IntegerAttr>(extractOp.getPosition()[0]).getInt(); rewriter.setInsertionPoint(extractOp); rewriter.replaceOpWithNewOp<vector::ExtractElementOp>( extractOp, extractOp.getVector(), @@ -946,8 +946,8 @@ // Find the distributed dimension. There should be exactly one. auto distributedType = - warpOp.getResult(operandNumber).getType().cast<VectorType>(); - auto yieldedType = operand->get().getType().cast<VectorType>(); + cast<VectorType>(warpOp.getResult(operandNumber).getType()); + auto yieldedType = cast<VectorType>(operand->get().getType()); int64_t distributedDim = -1; for (int64_t i = 0; i < yieldedType.getRank(); ++i) { if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) { @@ -1083,7 +1083,7 @@ auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>(); VectorType vecType = insertOp.getDestVectorType(); VectorType distrType = - warpOp.getResult(operandNumber).getType().cast<VectorType>(); + cast<VectorType>(warpOp.getResult(operandNumber).getType()); bool hasPos = static_cast<bool>(insertOp.getPosition()); // Yield destination vector, source scalar and position from warp op.
@@ -1171,7 +1171,7 @@ // Rewrite vector.insert with 1d dest to vector.insertelement. if (insertOp.getDestVectorType().getRank() == 1) { assert(insertOp.getPosition().size() == 1 && "expected 1 index"); - int64_t pos = insertOp.getPosition()[0].cast<IntegerAttr>().getInt(); + int64_t pos = cast<IntegerAttr>(insertOp.getPosition()[0]).getInt(); rewriter.setInsertionPoint(insertOp); rewriter.replaceOpWithNewOp<vector::InsertElementOp>( insertOp, insertOp.getSource(), insertOp.getDest(), @@ -1199,8 +1199,8 @@ // Find the distributed dimension. There should be exactly one. auto distrDestType = - warpOp.getResult(operandNumber).getType().cast<VectorType>(); - auto yieldedType = operand->get().getType().cast<VectorType>(); + cast<VectorType>(warpOp.getResult(operandNumber).getType()); + auto yieldedType = cast<VectorType>(operand->get().getType()); int64_t distrDestDim = -1; for (int64_t i = 0; i < yieldedType.getRank(); ++i) { if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) { @@ -1213,7 +1213,7 @@ assert(distrDestDim != -1 && "could not find distributed dimension"); // Compute the distributed source vector type. - VectorType srcVecType = insertOp.getSourceType().cast<VectorType>(); + VectorType srcVecType = cast<VectorType>(insertOp.getSourceType()); SmallVector<int64_t> distrSrcShape(srcVecType.getShape().begin(), srcVecType.getShape().end()); // E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32> @@ -1248,7 +1248,7 @@ int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim); SmallVector<int64_t> newPos = llvm::to_vector( llvm::map_range(insertOp.getPosition(), [](Attribute attr) { - return attr.cast<IntegerAttr>().getInt(); + return cast<IntegerAttr>(attr).getInt(); })); // tid of inserting lane: pos / elementsPerLane Value insertingLane = rewriter.create<arith::ConstantIndexOp>( @@ -1337,7 +1337,7 @@ if (!escapingValues.insert(operand->get())) return; Type distType = operand->get().getType(); - if (auto vecType = distType.cast<VectorType>()) { + if (auto vecType = cast<VectorType>(distType)) { AffineMap map = distributionMapFn(operand->get()); distType = getDistributedType(vecType, map, warpOp.getWarpSize()); } @@ -1359,7 +1359,7 @@ for (OpOperand &yieldOperand : yield->getOpOperands()) { if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) continue; - auto forResult = yieldOperand.get().cast<OpResult>(); + auto forResult = cast<OpResult>(yieldOperand.get()); newOperands.push_back( newWarpOp.getResult(yieldOperand.getOperandNumber())); yieldOperand.set(forOp.getIterOperands()[forResult.getResultNumber()]); @@ -1463,7 +1463,7 @@ auto reductionOp = cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp()); - auto vectorType = reductionOp.getVector().getType().cast<VectorType>(); + auto vectorType = cast<VectorType>(reductionOp.getVector().getType()); // Only rank 1 vectors supported. if (vectorType.getRank() != 1) return rewriter.notifyMatchFailure( @@ -1564,7 +1564,7 @@ // operations from there.
for (auto &op : body->without_terminator()) { bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) { - return result.getType().isa<VectorType>(); + return isa<VectorType>(result.getType()); }); if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody)) opsToMove.insert(&op); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -136,10 +136,10 @@ Type oldSrcType = insertOp.getSourceType(); Type newSrcType = oldSrcType; int64_t oldSrcRank = 0, newSrcRank = 0; - if (auto type = oldSrcType.dyn_cast<VectorType>()) { + if (auto type = dyn_cast<VectorType>(oldSrcType)) { newSrcType = trimLeadingOneDims(type); oldSrcRank = type.getRank(); - newSrcRank = newSrcType.cast<VectorType>().getRank(); + newSrcRank = cast<VectorType>(newSrcType).getRank(); } VectorType oldDstType = insertOp.getDestVectorType(); @@ -199,7 +199,7 @@ if (read.getMask()) return failure(); - auto shapedType = read.getSource().getType().cast<ShapedType>(); + auto shapedType = cast<ShapedType>(read.getSource().getType()); if (shapedType.getElementType() != read.getVectorType().getElementType()) return failure(); @@ -247,7 +247,7 @@ if (write.getMask()) return failure(); - auto shapedType = write.getSource().getType().dyn_cast<ShapedType>(); + auto shapedType = dyn_cast<ShapedType>(write.getSource().getType()); if (shapedType.getElementType() != write.getVectorType().getElementType()) return failure(); @@ -284,7 +284,7 @@ LogicalResult mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, RewriterBase &rewriter) { - VectorType oldAccType = contractOp.getAccType().dyn_cast<VectorType>(); + VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType()); if (oldAccType == nullptr) return failure(); if (oldAccType.getRank() < 2) @@ -418,7 +418,7 @@ PatternRewriter &rewriter) const override { if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1) return failure(); - auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>(); + auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0]); if (!vecType) return failure(); VectorType newVecType = trimLeadingOneDims(vecType); @@ -427,7 +427,7 @@ int64_t dropDim = vecType.getRank() - newVecType.getRank(); SmallVector<Value, 4> newOperands; for (Value operand : op->getOperands()) { - if (auto opVecType = operand.getType().dyn_cast<VectorType>()) { + if (auto opVecType = dyn_cast<VectorType>(operand.getType())) { newOperands.push_back(rewriter.create<vector::ExtractOp>( op->getLoc(), operand, splatZero(dropDim))); } else { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp @@ -21,7 +21,7 @@ // Helper that picks the proper sequence for inserting. static Value insertOne(PatternRewriter &rewriter, Location loc, Value from, Value into, int64_t offset) { - auto vectorType = into.getType().cast<VectorType>(); + auto vectorType = cast<VectorType>(into.getType()); if (vectorType.getRank() > 1) return rewriter.create<vector::InsertOp>(loc, from, into, offset); return rewriter.create<vector::InsertElementOp>( @@ -32,7 +32,7 @@ // Helper that picks the proper sequence for extracting.
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector, int64_t offset) { - auto vectorType = vector.getType().cast<VectorType>(); + auto vectorType = cast<VectorType>(vector.getType()); if (vectorType.getRank() > 1) return rewriter.create<vector::ExtractOp>(loc, vector, offset); return rewriter.create<vector::ExtractElementOp>( @@ -134,10 +134,10 @@ } int64_t offset = - op.getOffsets().getValue().front().cast<IntegerAttr>().getInt(); + cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt(); int64_t size = srcType.getShape().front(); int64_t stride = - op.getStrides().getValue().front().cast<IntegerAttr>().getInt(); + cast<IntegerAttr>(op.getStrides().getValue().front()).getInt(); auto loc = op.getLoc(); Value res = op.getDest(); @@ -174,7 +174,7 @@ off += stride, ++idx) { // 1. extract the proper subvector (or element) from source Value extractedSource = extractOne(rewriter, loc, op.getSource(), idx); - if (extractedSource.getType().isa<VectorType>()) { + if (isa<VectorType>(extractedSource.getType())) { // 2. If we have a vector, extract the proper subvector from destination // Otherwise we are at the element level and no need to recurse. Value extractedDest = extractOne(rewriter, loc, op.getDest(), off); @@ -208,11 +208,10 @@ assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets"); int64_t offset = - op.getOffsets().getValue().front().cast<IntegerAttr>().getInt(); - int64_t size = - op.getSizes().getValue().front().cast<IntegerAttr>().getInt(); + cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt(); + int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt(); int64_t stride = - op.getStrides().getValue().front().cast<IntegerAttr>().getInt(); + cast<IntegerAttr>(op.getStrides().getValue().front()).getInt(); assert(dstType.getElementType().isSignlessIntOrIndexOrFloat()); @@ -254,11 +253,10 @@ return failure(); int64_t offset = - op.getOffsets().getValue().front().cast<IntegerAttr>().getInt(); - int64_t size = - op.getSizes().getValue().front().cast<IntegerAttr>().getInt(); + cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt(); + int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt(); int64_t stride = - op.getStrides().getValue().front().cast<IntegerAttr>().getInt(); + cast<IntegerAttr>(op.getStrides().getValue().front()).getInt(); Location loc = op.getLoc(); SmallVector<Value> elements; @@ -300,11 +298,10 @@ assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets"); int64_t offset = - op.getOffsets().getValue().front().cast<IntegerAttr>().getInt(); - int64_t size = - op.getSizes().getValue().front().cast<IntegerAttr>().getInt(); + cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt(); + int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt(); int64_t stride = - op.getStrides().getValue().front().cast<IntegerAttr>().getInt(); + cast<IntegerAttr>(op.getStrides().getValue().front()).getInt(); auto loc = op.getLoc(); auto elemType = dstType.getElementType(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -261,7 +261,7 @@ llvm::make_filter_range(sizes, [](int64_t sz) { return sz != 1; })); Type rankReducedType = memref::SubViewOp::inferRankReducedResultType( targetShape, inputType, offsets, sizes, strides); - return canonicalizeStridedLayout(rankReducedType.cast<MemRefType>()); + return canonicalizeStridedLayout(cast<MemRefType>(rankReducedType)); } /// Creates a rank-reducing memref.subview op that drops unit dims from its @@ -269,7 +269,7 @@ static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter, mlir::Location loc, Value input) { - MemRefType inputType = input.getType().cast<MemRefType>(); +
MemRefType inputType = cast<MemRefType>(input.getType()); assert(inputType.hasStaticShape()); SmallVector<int64_t> subViewOffsets(inputType.getRank(), 0); SmallVector<int64_t> subViewStrides(inputType.getRank(), 1); @@ -304,9 +304,9 @@ PatternRewriter &rewriter) const override { auto loc = transferReadOp.getLoc(); Value vector = transferReadOp.getVector(); - VectorType vectorType = vector.getType().cast<VectorType>(); + VectorType vectorType = cast<VectorType>(vector.getType()); Value source = transferReadOp.getSource(); - MemRefType sourceType = source.getType().dyn_cast<MemRefType>(); + MemRefType sourceType = dyn_cast<MemRefType>(source.getType()); // TODO: support tensor types. if (!sourceType || !sourceType.hasStaticShape()) return failure(); @@ -347,9 +347,9 @@ PatternRewriter &rewriter) const override { auto loc = transferWriteOp.getLoc(); Value vector = transferWriteOp.getVector(); - VectorType vectorType = vector.getType().cast<VectorType>(); + VectorType vectorType = cast<VectorType>(vector.getType()); Value source = transferWriteOp.getSource(); - MemRefType sourceType = source.getType().dyn_cast<MemRefType>(); + MemRefType sourceType = dyn_cast<MemRefType>(source.getType()); // TODO: support tensor type. if (!sourceType || !sourceType.hasStaticShape()) return failure(); @@ -406,7 +406,7 @@ /// input starting at `firstDimToCollapse`. static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc, Value input, int64_t firstDimToCollapse) { - ShapedType inputType = input.getType().cast<ShapedType>(); + ShapedType inputType = cast<ShapedType>(input.getType()); if (inputType.getRank() == 1) return input; SmallVector<ReassociationIndices> reassociation; @@ -451,9 +451,9 @@ PatternRewriter &rewriter) const override { auto loc = transferReadOp.getLoc(); Value vector = transferReadOp.getVector(); - VectorType vectorType = vector.getType().cast<VectorType>(); + VectorType vectorType = cast<VectorType>(vector.getType()); Value source = transferReadOp.getSource(); - MemRefType sourceType = source.getType().dyn_cast<MemRefType>(); + MemRefType sourceType = dyn_cast<MemRefType>(source.getType()); // Contiguity check is valid on tensors only. if (!sourceType) return failure(); @@ -481,7 +481,7 @@ Value collapsedSource = collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim); MemRefType collapsedSourceType = - collapsedSource.getType().dyn_cast<MemRefType>(); + dyn_cast<MemRefType>(collapsedSource.getType()); int64_t collapsedRank = collapsedSourceType.getRank(); assert(collapsedRank == firstContiguousInnerDim + 1); SmallVector<AffineExpr, 1> dimExprs{ @@ -494,7 +494,7 @@ loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap); flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true})); rewriter.replaceOpWithNewOp<vector::ShapeCastOp>( - transferReadOp, vector.getType().cast<VectorType>(), flatRead); + transferReadOp, cast<VectorType>(vector.getType()), flatRead); return success(); } }; @@ -511,9 +511,9 @@ PatternRewriter &rewriter) const override { auto loc = transferWriteOp.getLoc(); Value vector = transferWriteOp.getVector(); - VectorType vectorType = vector.getType().cast<VectorType>(); + VectorType vectorType = cast<VectorType>(vector.getType()); Value source = transferWriteOp.getSource(); - MemRefType sourceType = source.getType().dyn_cast<MemRefType>(); + MemRefType sourceType = dyn_cast<MemRefType>(source.getType()); // Contiguity check is valid on tensors only.
if (!sourceType) return failure(); @@ -541,7 +541,7 @@ Value collapsedSource = collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim); MemRefType collapsedSourceType = - collapsedSource.getType().cast<MemRefType>(); + cast<MemRefType>(collapsedSource.getType()); int64_t collapsedRank = collapsedSourceType.getRank(); assert(collapsedRank == firstContiguousInnerDim + 1); SmallVector<AffineExpr, 1> dimExprs{ @@ -610,7 +610,7 @@ *getConstantIntValue(ofr)); } } - if (xferOp.getSource().getType().isa<MemRefType>()) { + if (isa<MemRefType>(xferOp.getSource().getType())) { rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(), newIndices); } else { @@ -637,7 +637,7 @@ LogicalResult matchAndRewrite(vector::ExtractOp extractOp, PatternRewriter &rewriter) const override { // Only match scalar extracts. - if (extractOp.getType().isa<VectorType>()) + if (isa<VectorType>(extractOp.getType())) return failure(); auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>(); if (!xferOp) @@ -660,7 +660,7 @@ SmallVector<Value> newIndices(xferOp.getIndices().begin(), xferOp.getIndices().end()); for (const auto &it : llvm::enumerate(extractOp.getPosition())) { - int64_t offset = it.value().cast<IntegerAttr>().getInt(); + int64_t offset = cast<IntegerAttr>(it.value()).getInt(); int64_t idx = newIndices.size() - extractOp.getPosition().size() + it.index(); OpFoldResult ofr = affine::makeComposedFoldedAffineApply( @@ -673,7 +673,7 @@ extractOp.getLoc(), *getConstantIntValue(ofr)); } } - if (xferOp.getSource().getType().isa<MemRefType>()) { + if (isa<MemRefType>(xferOp.getSource().getType())) { rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(), newIndices); } else { @@ -714,7 +714,7 @@ xferOp.getVector(), pos); } // Construct a scalar store. - if (xferOp.getSource().getType().isa<MemRefType>()) { + if (isa<MemRefType>(xferOp.getSource().getType())) { rewriter.replaceOpWithNewOp<memref::StoreOp>( xferOp, scalar, xferOp.getSource(), xferOp.getIndices()); } else { @@ -732,12 +732,12 @@ // Run store to load forwarding first since it can expose more dead store // opportunity. rootOp->walk([&](vector::TransferReadOp read) { - if (read.getShapedType().isa<MemRefType>()) + if (isa<MemRefType>(read.getShapedType())) opt.storeToLoadForwarding(read); }); opt.removeDeadOp(); rootOp->walk([&](vector::TransferWriteOp write) { - if (write.getShapedType().isa<MemRefType>()) + if (isa<MemRefType>(write.getShapedType())) opt.deadStoreOp(write); }); opt.removeDeadOp(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp @@ -190,7 +190,7 @@ Location loc = xferOp.getLoc(); int64_t memrefRank = xferOp.getShapedType().getRank(); // TODO: relax this precondition, will require rank-reducing subviews.
- assert(memrefRank == alloc.getType().cast<MemRefType>().getRank() && + assert(memrefRank == cast<MemRefType>(alloc.getType()).getRank() && "Expected memref rank to match the alloc rank"); ValueRange leadingIndices = xferOp.indices().take_front(xferOp.getLeadingShapedRank()); @@ -571,8 +571,8 @@ } MemRefType compatibleMemRefType = - getCastCompatibleMemRefType(xferOp.getShapedType().cast<MemRefType>(), - alloc.getType().cast<MemRefType>()); + getCastCompatibleMemRefType(cast<MemRefType>(xferOp.getShapedType()), + cast<MemRefType>(alloc.getType())); if (!compatibleMemRefType) return failure(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -93,9 +93,9 @@ PatternRewriter &rewriter) const override { // Check if 'shapeCastOp' has vector source/result type. auto sourceVectorType = - shapeCastOp.getSource().getType().dyn_cast_or_null<VectorType>(); + dyn_cast_or_null<VectorType>(shapeCastOp.getSource().getType()); auto resultVectorType = - shapeCastOp.getResult().getType().dyn_cast_or_null<VectorType>(); + dyn_cast_or_null<VectorType>(shapeCastOp.getResult().getType()); if (!sourceVectorType || !resultVectorType) return failure(); @@ -105,7 +105,7 @@ if (!sourceShapeCastOp) return failure(); auto operandSourceVectorType = - sourceShapeCastOp.getSource().getType().cast<VectorType>(); + cast<VectorType>(sourceShapeCastOp.getSource().getType()); auto operandResultVectorType = sourceShapeCastOp.getType(); // Check if shape cast operations invert each other. @@ -342,7 +342,7 @@ if (!broadcast) continue; // contractionOp can only take vector as operands. - auto srcType = broadcast.getSourceType().dyn_cast<VectorType>(); + auto srcType = dyn_cast<VectorType>(broadcast.getSourceType()); if (!srcType || srcType.getRank() == broadcast.getResultVectorType().getRank()) continue; @@ -455,7 +455,7 @@ return failure(); Type castResTy = getElementTypeOrSelf(op->getResult(0)); - if (auto vecTy = bcastOp.getSourceType().dyn_cast<VectorType>()) + if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType())) castResTy = VectorType::get(vecTy.getShape(), castResTy); auto *castOp = rewriter.create(op->getLoc(), op->getName().getIdentifier(), @@ -530,7 +530,7 @@ // This is a constant. Create a reverse transpose op for it.
auto vectorType = VectorType::get( srcType.getShape(), - operand.getType().cast<VectorType>().getElementType()); + cast<VectorType>(operand.getType()).getElementType()); srcValues.push_back(rewriter.create<vector::TransposeOp>( operand.getLoc(), vectorType, operand, rewriter.getI64ArrayAttr(invOrder))); @@ -539,7 +539,7 @@ auto vectorType = VectorType::get( srcType.getShape(), - op->getResultTypes()[0].cast<VectorType>().getElementType()); + cast<VectorType>(op->getResultTypes()[0]).getElementType()); Operation *elementwiseOp = rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues, vectorType, op->getAttrs()); @@ -693,7 +693,7 @@ } SmallVector<int64_t> dims = - llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape()); + llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape()); dims.back() = dims.back() / expandRatio; VectorType newExtractType = VectorType::get(dims, castSrcType.getElementType()); @@ -996,7 +996,7 @@ LogicalResult matchAndRewrite(vector::CreateMaskOp op, PatternRewriter &rewriter) const override { auto dstType = op.getType(); - if (dstType.cast<VectorType>().isScalable()) + if (cast<VectorType>(dstType).isScalable()) return failure(); int64_t rank = dstType.getRank(); if (rank > 1) @@ -1026,7 +1026,7 @@ if (readOp.getMask()) return failure(); - auto srcType = readOp.getSource().getType().dyn_cast<MemRefType>(); + auto srcType = dyn_cast<MemRefType>(readOp.getSource().getType()); if (!srcType || !srcType.hasStaticShape()) return failure(); @@ -1060,13 +1060,13 @@ MemRefType resultMemrefType; MemRefLayoutAttrInterface layout = srcType.getLayout(); - if (layout.isa<AffineMapAttr>() && layout.isIdentity()) { + if (isa<AffineMapAttr>(layout) && layout.isIdentity()) { resultMemrefType = MemRefType::get( srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(), nullptr, srcType.getMemorySpace()); } else { MemRefLayoutAttrInterface updatedLayout; - if (auto strided = layout.dyn_cast<StridedLayoutAttr>()) { + if (auto strided = dyn_cast<StridedLayoutAttr>(layout)) { auto strides = llvm::to_vector(strided.getStrides().drop_back(dimsToDrop)); updatedLayout = StridedLayoutAttr::get(strided.getContext(), @@ -1099,7 +1099,7 @@ loc, resultMemrefType, readOp.getSource(), offsets, srcType.getShape(), strides); auto permMap = getTransferMinorIdentityMap( - rankedReducedView.getType().cast<ShapedType>(), resultTargetVecType); + cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType); Value result = rewriter.create<vector::TransferReadOp>( loc, resultTargetVecType, rankedReducedView, readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap), diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -316,7 +316,7 @@ auto targetShape = getTargetShape(options, contractOp); if (!targetShape) return failure(); - auto dstVecType = contractOp.getResultType().cast<VectorType>(); + auto dstVecType = cast<VectorType>(contractOp.getResultType()); SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll(); Location loc = contractOp.getLoc(); @@ -491,7 +491,7 @@ auto targetShape = getTargetShape(options, op); if (!targetShape) return failure(); - auto dstVecType = op->getResult(0).getType().cast<VectorType>(); + auto dstVecType = cast<VectorType>(op->getResult(0).getType()); SmallVector<int64_t> originalSize = *cast<VectorUnrollOpInterface>(op).getShapeForUnroll(); SmallVector<int64_t> ratio = *computeShapeRatio(originalSize, *targetShape); @@ -512,7 +512,7 @@ getVectorOffset(ratioStrides, i, *targetShape); SmallVector<Value> extractOperands; for (OpOperand &operand : op->getOpOperands()) { - auto vecType = operand.get().getType().template dyn_cast<VectorType>(); + auto vecType = dyn_cast<VectorType>(operand.get().getType()); if
(!vecType) { extractOperands.push_back(operand.get()); continue; diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -36,9 +36,9 @@ /// the type of `source`. Value mlir::vector::createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim) { - if (source.getType().isa<UnrankedMemRefType, MemRefType>()) + if (isa<UnrankedMemRefType, MemRefType>(source.getType())) return b.createOrFold<memref::DimOp>(loc, source, dim); - if (source.getType().isa<UnrankedTensorType, RankedTensorType>()) + if (isa<UnrankedTensorType, RankedTensorType>(source.getType())) return b.createOrFold<tensor::DimOp>(loc, source, dim); llvm_unreachable("Expected MemRefType or TensorType"); } @@ -166,7 +166,7 @@ } return false; } else if (op.getNumResults() == 1) { - if (auto v = op.getResult(0).getType().dyn_cast<VectorType>()) { + if (auto v = dyn_cast<VectorType>(op.getResult(0).getType())) { superVectorType = v; } else { // Not a vector type. diff --git a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp --- a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp +++ b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp @@ -266,7 +266,7 @@ SmallVector<int64_t> transp; for (auto attr : op.getTransp()) - transp.push_back(attr.cast<IntegerAttr>().getInt()); + transp.push_back(cast<IntegerAttr>(attr).getInt()); // Check whether the two source vector dimensions that are greater than one // must be transposed with each other so that we can apply one of the 2-D diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp --- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp @@ -22,11 +22,11 @@ /// Extracts the "main" vector element type from the given X86Vector operation.
template <typename OpTy> static Type getSrcVectorElementType(OpTy op) { - return op.getSrc().getType().template cast<VectorType>().getElementType(); + return cast<VectorType>(op.getSrc().getType()).getElementType(); } template <> Type getSrcVectorElementType(Vp2IntersectOp op) { - return op.getA().getType().template cast<VectorType>().getElementType(); + return cast<VectorType>(op.getA().getType()).getElementType(); } namespace { diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp --- a/mlir/lib/ExecutionEngine/JitRunner.cpp +++ b/mlir/lib/ExecutionEngine/JitRunner.cpp @@ -288,30 +288,27 @@ Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction); template <> Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) { - auto resultType = mainFunction.getFunctionType() - .cast<LLVM::LLVMFunctionType>() - .getReturnType() - .dyn_cast<IntegerType>(); + auto resultType = dyn_cast<IntegerType>(mainFunction.getFunctionType() .cast<LLVM::LLVMFunctionType>() .getReturnType()); if (!resultType || resultType.getWidth() != 32) return makeStringError("only single i32 function result supported"); return Error::success(); } template <> Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) { - auto resultType = mainFunction.getFunctionType() - .cast<LLVM::LLVMFunctionType>() - .getReturnType() - .dyn_cast<IntegerType>(); + auto resultType = dyn_cast<IntegerType>(mainFunction.getFunctionType() .cast<LLVM::LLVMFunctionType>() .getReturnType()); if (!resultType || resultType.getWidth() != 64) return makeStringError("only single i64 function result supported"); return Error::success(); } template <> Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) { - if (!mainFunction.getFunctionType() - .cast<LLVM::LLVMFunctionType>() - .getReturnType() - .isa<Float32Type>()) + if (!isa<Float32Type>(mainFunction.getFunctionType() .cast<LLVM::LLVMFunctionType>() .getReturnType())) return makeStringError("only single f32 function result supported"); return Error::success(); } @@ -324,8 +321,7 @@ if (!mainFunction || mainFunction.isExternal()) return makeStringError("entry point not found"); - if (mainFunction.getFunctionType() - .cast<LLVM::LLVMFunctionType>() + if (cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType()) .getNumParams() != 0) return makeStringError("function inputs not supported"); diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp --- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp +++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp @@ -37,7 +37,7 @@ static unsigned getIndexBitwidth(DataLayoutEntryListRef params) { if (params.empty()) return 64; - auto attr = params.front().getValue().cast<IntegerAttr>(); + auto attr = cast<IntegerAttr>(params.front().getValue()); return attr.getValue().getZExtValue(); } @@ -51,10 +51,10 @@ unsigned mlir::detail::getDefaultTypeSizeInBits(Type type, const DataLayout &dataLayout, DataLayoutEntryListRef params) { - if (type.isa<IntegerType, FloatType>()) + if (isa<IntegerType, FloatType>(type)) return type.getIntOrFloatBitWidth(); - if (auto ctype = type.dyn_cast<ComplexType>()) { + if (auto ctype = dyn_cast<ComplexType>(type)) { auto et = ctype.getElementType(); auto innerAlignment = getDefaultPreferredAlignment(et, dataLayout, params) * 8; @@ -66,7 +66,7 @@ } // Index is an integer of some bitwidth. - if (type.isa<IndexType>()) + if (isa<IndexType>(type)) return dataLayout.getTypeSizeInBits( IntegerType::get(type.getContext(), getIndexBitwidth(params))); @@ -75,12 +75,12 @@ // there is no bit-packing at the moment element sizes are taken in bytes and // multiplied with 8 bits. // TODO: make this extensible.
- if (auto vecType = type.dyn_cast()) + if (auto vecType = dyn_cast(type)) return vecType.getNumElements() / vecType.getShape().back() * llvm::PowerOf2Ceil(vecType.getShape().back()) * dataLayout.getTypeSize(vecType.getElementType()) * 8; - if (auto typeInterface = type.dyn_cast()) + if (auto typeInterface = dyn_cast(type)) return typeInterface.getTypeSizeInBits(dataLayout, params); reportMissingDataLayout(type); @@ -104,7 +104,7 @@ static unsigned extractABIAlignment(DataLayoutEntryInterface entry) { auto values = - entry.getValue().cast().getValues(); + cast(entry.getValue()).getValues(); return *values.begin() / 8u; } @@ -134,24 +134,24 @@ Type type, const DataLayout &dataLayout, ArrayRef params) { // Natural alignment is the closest power-of-two number above. - if (type.isa()) + if (isa(type)) return llvm::PowerOf2Ceil(dataLayout.getTypeSize(type)); - if (auto fltType = type.dyn_cast()) + if (auto fltType = dyn_cast(type)) return getFloatTypeABIAlignment(fltType, dataLayout, params); // Index is an integer of some bitwidth. - if (type.isa()) + if (isa(type)) return dataLayout.getTypeABIAlignment( IntegerType::get(type.getContext(), getIndexBitwidth(params))); - if (auto intType = type.dyn_cast()) + if (auto intType = dyn_cast(type)) return getIntegerTypeABIAlignment(intType, params); - if (auto ctype = type.dyn_cast()) + if (auto ctype = dyn_cast(type)) return getDefaultABIAlignment(ctype.getElementType(), dataLayout, params); - if (auto typeInterface = type.dyn_cast()) + if (auto typeInterface = dyn_cast(type)) return typeInterface.getABIAlignment(dataLayout, params); reportMissingDataLayout(type); @@ -159,7 +159,7 @@ static unsigned extractPreferredAlignment(DataLayoutEntryInterface entry) { auto values = - entry.getValue().cast().getValues(); + cast(entry.getValue()).getValues(); return *std::next(values.begin(), values.size() - 1) / 8u; } @@ -187,27 +187,27 @@ Type type, const DataLayout &dataLayout, ArrayRef params) { // Preferred alignment is same as natural for floats and vectors. - if (type.isa()) + if (isa(type)) return dataLayout.getTypeABIAlignment(type); - if (auto fltType = type.dyn_cast()) + if (auto fltType = dyn_cast(type)) return getFloatTypePreferredAlignment(fltType, dataLayout, params); // Preferred alignment is the closest power-of-two number above for integers // (ABI alignment may be smaller). 
- if (auto intType = type.dyn_cast()) + if (auto intType = dyn_cast(type)) return getIntegerTypePreferredAlignment(intType, dataLayout, params); - if (type.isa()) { + if (isa(type)) { return dataLayout.getTypePreferredAlignment( IntegerType::get(type.getContext(), getIndexBitwidth(params))); } - if (auto ctype = type.dyn_cast()) + if (auto ctype = dyn_cast(type)) return getDefaultPreferredAlignment(ctype.getElementType(), dataLayout, params); - if (auto typeInterface = type.dyn_cast()) + if (auto typeInterface = dyn_cast(type)) return typeInterface.getPreferredAlignment(dataLayout, params); reportMissingDataLayout(type); @@ -232,7 +232,7 @@ if (entry == DataLayoutEntryInterface()) return 0; - auto value = entry.getValue().cast(); + auto value = cast(entry.getValue()); return value.getValue().getZExtValue(); } @@ -543,19 +543,19 @@ for (const auto &kvp : types) { auto sampleType = kvp.second.front().getKey().get(); - if (sampleType.isa()) { + if (isa(sampleType)) { assert(kvp.second.size() == 1 && "expected one data layout entry for non-parametric 'index' type"); - if (!kvp.second.front().getValue().isa()) + if (!isa(kvp.second.front().getValue())) return emitError(loc) << "expected integer attribute in the data layout entry for " << sampleType; continue; } - if (sampleType.isa()) { + if (isa(sampleType)) { for (DataLayoutEntryInterface entry : kvp.second) { - auto value = entry.getValue().dyn_cast(); + auto value = dyn_cast(entry.getValue()); if (!value || !value.getElementType().isSignlessInteger(32)) { emitError(loc) << "expected a dense i32 elements attribute in the " "data layout entry " @@ -587,7 +587,7 @@ if (isa(&sampleType.getDialect())) return emitError(loc) << "unexpected data layout for a built-in type"; - auto dlType = sampleType.dyn_cast(); + auto dlType = dyn_cast(sampleType); if (!dlType) return emitError(loc) << "data layout specified for a type that does not support it"; diff --git a/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp --- a/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp +++ b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp @@ -29,9 +29,9 @@ SmallVector outputBufferOperands, outputTensorOperands; for (OpOperand *operand : dstStyleOp.getDpsInitOperands()) { Type type = operand->get().getType(); - if (type.isa()) { + if (isa(type)) { outputBufferOperands.push_back(operand); - } else if (type.isa()) { + } else if (isa(type)) { outputTensorOperands.push_back(operand); } else { return op->emitOpError("expected that operand #") diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp --- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp +++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp @@ -30,7 +30,7 @@ unsigned ConstantIntRanges::getStorageBitwidth(Type type) { if (type.isIndex()) return IndexType::kInternalStorageBitWidth; - if (auto integerType = type.dyn_cast()) + if (auto integerType = dyn_cast(type)) return integerType.getWidth(); // Non-integer types have their bounds stored in width 0 `APInt`s. return 0; diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp --- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp +++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp @@ -36,7 +36,7 @@ // a correct result. 
int64_t resultIdx = 0; for (OpResult result : op->getResults()) { - auto shapedType = result.getType().dyn_cast(); + auto shapedType = dyn_cast(result.getType()); if (!shapedType) continue; if (!shapedType.hasRank()) { @@ -69,7 +69,7 @@ if (val.isNull()) return false; if (auto t = val.dyn_cast()) - return t.cast().hasRank(); + return cast(t).hasRank(); if (val.is()) return true; return val.get()->hasRank(); @@ -79,7 +79,7 @@ if (val.isNull()) return nullptr; if (auto t = val.dyn_cast()) - return t.cast().getElementType(); + return cast(t).getElementType(); if (val.is()) return nullptr; return val.get()->getElementType(); @@ -88,10 +88,10 @@ void ShapeAdaptor::getDims(SmallVectorImpl &res) const { assert(hasRank()); if (auto t = val.dyn_cast()) { - ArrayRef vals = t.cast().getShape(); + ArrayRef vals = cast(t).getShape(); res.assign(vals.begin(), vals.end()); } else if (auto attr = val.dyn_cast()) { - auto dattr = attr.cast(); + auto dattr = cast(attr); res.clear(); res.reserve(dattr.size()); for (auto it : dattr.getValues()) @@ -111,9 +111,9 @@ int64_t ShapeAdaptor::getDimSize(int index) const { assert(hasRank()); if (auto t = val.dyn_cast()) - return t.cast().getDimSize(index); + return cast(t).getDimSize(index); if (auto attr = val.dyn_cast()) - return attr.cast() + return cast(attr) .getValues()[index] .getSExtValue(); auto *stc = val.get(); @@ -123,9 +123,9 @@ int64_t ShapeAdaptor::getRank() const { assert(hasRank()); if (auto t = val.dyn_cast()) - return t.cast().getRank(); + return cast(t).getRank(); if (auto attr = val.dyn_cast()) - return attr.cast().size(); + return cast(attr).size(); return val.get()->getDims().size(); } @@ -134,9 +134,9 @@ return false; if (auto t = val.dyn_cast()) - return t.cast().hasStaticShape(); + return cast(t).hasStaticShape(); if (auto attr = val.dyn_cast()) { - auto dattr = attr.cast(); + auto dattr = cast(attr); for (auto index : dattr.getValues()) if (ShapedType::isDynamic(index.getSExtValue())) return false; @@ -150,10 +150,10 @@ assert(hasStaticShape() && "cannot get element count of dynamic shaped type"); if (auto t = val.dyn_cast()) - return t.cast().getNumElements(); + return cast(t).getNumElements(); if (auto attr = val.dyn_cast()) { - auto dattr = attr.cast(); + auto dattr = cast(attr); int64_t num = 1; for (auto index : dattr.getValues()) { num *= index.getZExtValue(); diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp --- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp +++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp @@ -34,7 +34,7 @@ } // Case 2: Check for IntegerAttr. 
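The next hunk mixes two cast styles deliberately: ofr is an OpFoldResult, a PointerUnion of Value and Attribute, and PointerUnion's own member dyn_cast is not part of this migration; only the cast on the unwrapped Attribute moves to the free-function form. A sketch of the resulting idiom (the headers and helper name are assumptions):

#include <cstdint>
#include <optional>
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/OpDefinition.h"

// Sketch: the union keeps its member cast; the Attribute cast is free.
static std::optional<int64_t> constantIntValue(mlir::OpFoldResult ofr) {
  mlir::Attribute attr = ofr.dyn_cast<mlir::Attribute>(); // PointerUnion API, unchanged
  if (auto intAttr = llvm::dyn_cast_or_null<mlir::IntegerAttr>(attr))
    return intAttr.getValue().getSExtValue();
  return std::nullopt;
}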
Attribute attr = ofr.dyn_cast(); - if (auto intAttr = attr.dyn_cast_or_null()) + if (auto intAttr = dyn_cast_or_null(attr)) return intAttr.getValue().getSExtValue(); return std::nullopt; } @@ -137,8 +137,8 @@ std::optional dim) const { #ifndef NDEBUG assertValidValueDim(value, dim); - assert((value.isa() || - value.cast().getOwner()->isEntryBlock()) && + assert((isa(value) || + cast(value).getOwner()->isEntryBlock()) && "unstructured control flow is not supported"); #endif // NDEBUG @@ -149,7 +149,7 @@ } static Operation *getOwnerOfValue(Value value) { - if (auto bbArg = value.dyn_cast()) + if (auto bbArg = dyn_cast(value)) return bbArg.getOwner()->getParentOp(); return value.getDefiningOp(); } diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -402,7 +402,7 @@ .Case( [](Type) { return PDLValue::Kind::Operation; }) .Case([](pdl::RangeType rangeTy) { - if (rangeTy.getElementType().isa()) + if (isa(rangeTy.getElementType())) return PDLValue::Kind::TypeRange; return PDLValue::Kind::ValueRange; }) @@ -538,11 +538,11 @@ ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0; auto processRewriterValue = [&](Value val) { valueToMemIndex.try_emplace(val, index++); - if (pdl::RangeType rangeType = val.getType().dyn_cast()) { + if (pdl::RangeType rangeType = dyn_cast(val.getType())) { Type elementTy = rangeType.getElementType(); - if (elementTy.isa()) + if (isa(elementTy)) valueToRangeIndex.try_emplace(val, typeRangeIndex++); - else if (elementTy.isa()) + else if (isa(elementTy)) valueToRangeIndex.try_emplace(val, valueRangeIndex++); } }; @@ -611,13 +611,13 @@ /*dummyValue*/ 0); // Check to see if this value is a range type. - if (auto rangeTy = value.getType().dyn_cast()) { + if (auto rangeTy = dyn_cast(value.getType())) { Type eleType = rangeTy.getElementType(); - if (eleType.isa()) + if (isa(eleType)) defRangeIt->second.opRangeIndex = 0; - else if (eleType.isa()) + else if (isa(eleType)) defRangeIt->second.typeRangeIndex = 0; - else if (eleType.isa()) + else if (isa(eleType)) defRangeIt->second.valueRangeIndex = 0; } }; @@ -792,14 +792,14 @@ #endif // Range results also need to append the range storage index. 
- if (result.getType().isa()) + if (isa(result.getType())) writer.append(getRangeStorageIndex(result)); writer.append(result); } } void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) { Value lhs = op.getLhs(); - if (lhs.getType().isa()) { + if (isa(lhs.getType())) { writer.append(OpCode::AreRangesEqual); writer.appendPDLValueKind(lhs); writer.append(op.getLhs(), op.getRhs(), op.getSuccessors()); @@ -945,7 +945,7 @@ writer.append(OpCode::GetOperands, index.value_or(std::numeric_limits::max()), op.getInputOp()); - if (result.getType().isa()) + if (isa(result.getType())) writer.append(getRangeStorageIndex(result)); else writer.append(std::numeric_limits::max()); @@ -965,7 +965,7 @@ writer.append(OpCode::GetResults, index.value_or(std::numeric_limits::max()), op.getInputOp()); - if (result.getType().isa()) + if (isa(result.getType())) writer.append(getRangeStorageIndex(result)); else writer.append(std::numeric_limits::max()); @@ -979,7 +979,7 @@ } void Generator::generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer) { - if (op.getType().isa()) { + if (isa(op.getType())) { Value result = op.getResult(); writer.append(OpCode::GetValueRangeTypes, result, getRangeStorageIndex(result), op.getValue()); @@ -1016,7 +1016,7 @@ void Generator::generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer) { auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) { - return OperationName(attr.cast().getValue(), ctx); + return OperationName(cast(attr).getValue(), ctx); }); writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases, op.getSuccessors()); @@ -1566,7 +1566,7 @@ Attribute rhs = read(); LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); - selectJump(*lhs == rhs.cast().getAsValueRange()); + selectJump(*lhs == cast(rhs).getAsValueRange()); } void ByteCodeExecutor::executeContinue() { @@ -1581,7 +1581,7 @@ LLVM_DEBUG(llvm::dbgs() << "Executing CreateConstantTypeRange:\n"); unsigned memIndex = read(); unsigned rangeIndex = read(); - ArrayAttr typesAttr = read().cast(); + ArrayAttr typesAttr = cast(read()); LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n"); assignRangeToMemory(typesAttr.getAsValueRange(), memIndex, @@ -1743,7 +1743,7 @@ unsigned memIndex = read(); Attribute attr = read(); Type type; - if (auto typedAttr = attr.dyn_cast()) + if (auto typedAttr = dyn_cast(attr)) type = typedAttr.getType(); LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n" diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -190,7 +190,7 @@ // the FuncOp. if (emitter.shouldDeclareVariablesAtTop()) { // Skip the assignment if the emitc.constant has no value. - if (auto oAttr = value.dyn_cast()) { + if (auto oAttr = dyn_cast(value)) { if (oAttr.getValue().empty()) return success(); } @@ -201,7 +201,7 @@ } // Emit a variable declaration for an emitc.constant op without value. - if (auto oAttr = value.dyn_cast()) { + if (auto oAttr = dyn_cast(value)) { if (oAttr.getValue().empty()) // The semicolon gets printed by the emitOperation function. return emitter.emitVariableDeclaration(result, @@ -333,7 +333,7 @@ os << callOp.getCallee(); auto emitArgs = [&](Attribute attr) -> LogicalResult { - if (auto t = attr.dyn_cast()) { + if (auto t = dyn_cast(attr)) { // Index attributes are treated specially as operand index. 
if (t.getType().isIndex()) { int64_t idx = t.getInt(); @@ -759,11 +759,11 @@ }; // Print floating point attributes. - if (auto fAttr = attr.dyn_cast()) { + if (auto fAttr = dyn_cast(attr)) { printFloat(fAttr.getValue()); return success(); } - if (auto dense = attr.dyn_cast()) { + if (auto dense = dyn_cast(attr)) { os << '{'; interleaveComma(dense, os, [&](const APFloat &val) { printFloat(val); }); os << '}'; @@ -771,21 +771,19 @@ } // Print integer attributes. - if (auto iAttr = attr.dyn_cast()) { - if (auto iType = iAttr.getType().dyn_cast()) { + if (auto iAttr = dyn_cast(attr)) { + if (auto iType = dyn_cast(iAttr.getType())) { printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness())); return success(); } - if (auto iType = iAttr.getType().dyn_cast()) { + if (auto iType = dyn_cast(iAttr.getType())) { printInt(iAttr.getValue(), false); return success(); } } - if (auto dense = attr.dyn_cast()) { - if (auto iType = dense.getType() - .cast() - .getElementType() - .dyn_cast()) { + if (auto dense = dyn_cast(attr)) { + if (auto iType = dyn_cast( + dense.getType().cast().getElementType())) { os << '{'; interleaveComma(dense, os, [&](const APInt &val) { printInt(val, shouldMapToUnsigned(iType.getSignedness())); @@ -793,10 +791,8 @@ os << '}'; return success(); } - if (auto iType = dense.getType() - .cast() - .getElementType() - .dyn_cast()) { + if (auto iType = dyn_cast( + dense.getType().cast().getElementType())) { os << '{'; interleaveComma(dense, os, [&](const APInt &val) { printInt(val, false); }); @@ -806,13 +802,13 @@ } // Print opaque attributes. - if (auto oAttr = attr.dyn_cast()) { + if (auto oAttr = dyn_cast(attr)) { os << oAttr.getValue(); return success(); } // Print symbolic reference attributes. - if (auto sAttr = attr.dyn_cast()) { + if (auto sAttr = dyn_cast(attr)) { if (sAttr.getNestedReferences().size() > 1) return emitError(loc, "attribute has more than 1 nested reference"); os << sAttr.getRootReference().getValue(); @@ -820,7 +816,7 @@ } // Print type attributes. 
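All of these attribute printers share one dispatch idiom: try each attribute kind in turn, and let a failed dyn_cast (a null handle) fall through to the next case. A condensed sketch with a hypothetical function name; the real printers also handle the dense, opaque, symbol-reference, and type attributes shown in the surrounding hunks:

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"

// Sketch of the dispatch chain used by the emitter.
static mlir::LogicalResult printScalarAttr(llvm::raw_ostream &os,
                                           mlir::Attribute attr) {
  if (auto fAttr = llvm::dyn_cast<mlir::FloatAttr>(attr)) {
    os << fAttr.getValueAsDouble();
    return mlir::success();
  }
  if (auto iAttr = llvm::dyn_cast<mlir::IntegerAttr>(attr)) {
    os << iAttr.getValue(); // APInt printing
    return mlir::success();
  }
  return mlir::failure();
}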
- if (auto type = attr.dyn_cast()) + if (auto type = dyn_cast(attr)) return emitType(loc, type.getValue()); return emitError(loc, "cannot emit attribute: ") << attr; @@ -957,7 +953,7 @@ } LogicalResult CppEmitter::emitType(Location loc, Type type) { - if (auto iType = type.dyn_cast()) { + if (auto iType = dyn_cast(type)) { switch (iType.getWidth()) { case 1: return (os << "bool"), success(); @@ -973,7 +969,7 @@ return emitError(loc, "cannot emit integer type ") << type; } } - if (auto fType = type.dyn_cast()) { + if (auto fType = dyn_cast(type)) { switch (fType.getWidth()) { case 32: return (os << "float"), success(); @@ -983,9 +979,9 @@ return emitError(loc, "cannot emit float type ") << type; } } - if (auto iType = type.dyn_cast()) + if (auto iType = dyn_cast(type)) return (os << "size_t"), success(); - if (auto tType = type.dyn_cast()) { + if (auto tType = dyn_cast(type)) { if (!tType.hasRank()) return emitError(loc, "cannot emit unranked tensor type"); if (!tType.hasStaticShape()) @@ -1001,13 +997,13 @@ os << ">"; return success(); } - if (auto tType = type.dyn_cast()) + if (auto tType = dyn_cast(type)) return emitTupleType(loc, tType.getTypes()); - if (auto oType = type.dyn_cast()) { + if (auto oType = dyn_cast(type)) { os << oType.getValue(); return success(); } - if (auto pType = type.dyn_cast()) { + if (auto pType = dyn_cast(type)) { if (failed(emitType(loc, pType.getPointee()))) return failure(); os << "*"; diff --git a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp --- a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp @@ -21,8 +21,8 @@ /// A utility walker that interrupts if the operation has valid debug /// information. static WalkResult interruptIfValidLocation(Operation *op) { - return op->getLoc().isa() ? WalkResult::advance() - : WalkResult::interrupt(); + return isa(op->getLoc()) ? WalkResult::advance() + : WalkResult::interrupt(); } DebugTranslation::DebugTranslation(Operation *module, llvm::Module &llvmModule) @@ -45,7 +45,7 @@ if (auto targetTripleAttr = module->getAttr(LLVM::LLVMDialect::getTargetTripleAttrName())) { auto targetTriple = - llvm::Triple(targetTripleAttr.cast().getValue()); + llvm::Triple(cast(targetTripleAttr).getValue()); if (targetTriple.isKnownWindowsMSVCEnvironment()) { // Dwarf debugging files will be generated by default, unless "CodeView" // is set explicitly. Windows/MSVC should use CodeView instead. @@ -68,8 +68,8 @@ const bool hasCallWithoutDebugInfo = func.walk([&](LLVM::CallOp call) { return call.getLoc()->walk([](Location l) { - return l.isa() ? WalkResult::interrupt() - : WalkResult::advance(); + return isa(l) ? WalkResult::interrupt() + : WalkResult::advance(); }); }) .wasInterrupted(); @@ -273,7 +273,7 @@ DebugTranslation::translateLoc(Location loc, llvm::DILocalScope *scope, const llvm::DILocation *inlinedAt) { // LLVM doesn't have a representation for unknown. - if (!scope || loc.isa()) + if (!scope || isa(loc)) return nullptr; // Check for a cached instance. @@ -282,12 +282,12 @@ return existingIt->second; const llvm::DILocation *llvmLoc = nullptr; - if (auto callLoc = loc.dyn_cast()) { + if (auto callLoc = dyn_cast(loc)) { // For callsites, the caller is fed as the inlinedAt for the callee. 
const auto *callerLoc = translateLoc(callLoc.getCaller(), scope, inlinedAt); llvmLoc = translateLoc(callLoc.getCallee(), scope, callerLoc); - } else if (auto fileLoc = loc.dyn_cast()) { + } else if (auto fileLoc = dyn_cast(loc)) { llvm::DILocalScope *locationScope = scope; // Only construct a new DIFile when no local scope is present. This // prioritizes existing DI information when it's present. @@ -300,12 +300,12 @@ fileLoc.getColumn(), locationScope, const_cast(inlinedAt)); - } else if (auto fusedLoc = loc.dyn_cast()) { + } else if (auto fusedLoc = dyn_cast(loc)) { ArrayRef locations = fusedLoc.getLocations(); // Check for a scope encoded with the location. if (auto scopedAttr = - fusedLoc.getMetadata().dyn_cast_or_null()) + dyn_cast_or_null(fusedLoc.getMetadata())) scope = translate(scopedAttr); // For fused locations, merge each of the nodes. @@ -315,10 +315,10 @@ llvmLoc, translateLoc(locIt, scope, inlinedAt)); } - } else if (auto nameLoc = loc.dyn_cast()) { + } else if (auto nameLoc = dyn_cast(loc)) { llvmLoc = translateLoc(nameLoc.getChildLoc(), scope, inlinedAt); - } else if (auto opaqueLoc = loc.dyn_cast()) { + } else if (auto opaqueLoc = dyn_cast(loc)) { llvmLoc = translateLoc(opaqueLoc.getFallbackLocation(), scope, inlinedAt); } else { llvm_unreachable("unknown location kind"); diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -231,9 +231,9 @@ Attribute attr = it.value(); if (!attr) continue; - DictionaryAttr dAttr = attr.cast(); + DictionaryAttr dAttr = cast(attr); TypeAttr tAttr = - dAttr.get(InlineAsmOp::getElementTypeAttrName()).cast(); + cast(dAttr.get(InlineAsmOp::getElementTypeAttrName())); llvm::AttrBuilder b(moduleTranslation.getLLVMContext()); llvm::Type *ty = moduleTranslation.convertType(tAttr.getValue()); b.addTypeAttr(llvm::Attribute::ElementType, ty); diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp @@ -162,7 +162,7 @@ ->addOperand(llvmMetadataNode); }; if (attribute.getName() == NVVM::NVVMDialect::getMaxntidAttrName()) { - if (!attribute.getValue().dyn_cast()) + if (!dyn_cast(attribute.getValue())) return failure(); SmallVector values = extractFromI64ArrayAttr(attribute.getValue()); @@ -172,7 +172,7 @@ if (values.size() > 2) generateMetadata(values[2], NVVM::NVVMDialect::getMaxntidZName()); } else if (attribute.getName() == NVVM::NVVMDialect::getReqntidAttrName()) { - if (!attribute.getValue().dyn_cast()) + if (!dyn_cast(attribute.getValue())) return failure(); SmallVector values = extractFromI64ArrayAttr(attribute.getValue()); @@ -183,10 +183,10 @@ generateMetadata(values[2], NVVM::NVVMDialect::getReqntidZName()); } else if (attribute.getName() == NVVM::NVVMDialect::getMinctasmAttrName()) { - auto value = attribute.getValue().dyn_cast(); + auto value = dyn_cast(attribute.getValue()); generateMetadata(value.getInt(), "minctasm"); } else if (attribute.getName() == NVVM::NVVMDialect::getMaxnregAttrName()) { - auto value = attribute.getValue().dyn_cast(); + auto value = dyn_cast(attribute.getValue()); generateMetadata(value.getInt(), "maxnreg"); } else if (attribute.getName() == 
NVVM::NVVMDialect::getKernelFuncAttrName()) { diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp @@ -108,7 +108,7 @@ dataPtr = builder.CreateExtractValue(dataValue, kPtrPosInDataDescriptor); dataSize = builder.CreateExtractValue(dataValue, kSizePosInDataDescriptor); - } else if (data.getType().isa()) { + } else if (isa(data.getType())) { dataPtrBase = dataValue; dataPtr = dataValue; dataSize = accBuilder->getSizeInBytes(dataValue); diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -367,7 +367,7 @@ if (criticalOp.getNameAttr()) { // The verifiers in OpenMP Dialect guarantee that all the pointers are // non-null - auto symbolRef = criticalOp.getNameAttr().cast(); + auto symbolRef = cast(criticalOp.getNameAttr()); auto criticalDeclareOp = SymbolTable::lookupNearestSymbolFrom(criticalOp, symbolRef); @@ -389,7 +389,7 @@ for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i) { if (container.getReductionVars()[i] != reduction.getAccumulator()) continue; - reductionSymbol = (*container.getReductions())[i].cast(); + reductionSymbol = cast((*container.getReductions())[i]); break; } assert(reductionSymbol && @@ -705,7 +705,7 @@ llvm::zip(taskOp.getDependVars(), taskOp.getDepends()->getValue())) { llvm::omp::RTLDependenceKindTy type; switch ( - std::get<1>(dep).cast().getValue()) { + cast(std::get<1>(dep)).getValue()) { case mlir::omp::ClauseTaskDepend::taskdependin: type = llvm::omp::RTLDependenceKindTy::DepIn; break; @@ -1379,7 +1379,7 @@ llvm::Value *mapOpPtr; llvm::Value *mapOpSize; - if (mapOp.getType().isa()) { + if (isa(mapOp.getType())) { mapOpPtrBase = mapOpValue; mapOpPtr = mapOpValue; mapOpSize = ompBuilder->getSizeInBytes(mapOpValue); @@ -1410,7 +1410,7 @@ {builder.getInt32(0), builder.getInt32(index)}); builder.CreateStore(mapOpSize, sizeGEP); - mapTypeFlags.push_back(mapTypeOp.dyn_cast().getInt()); + mapTypeFlags.push_back(dyn_cast(mapTypeOp).getInt()); llvm::Constant *mapName = mlir::LLVM::createMappingInformation(mapOp.getLoc(), *ompBuilder); mapNames.push_back(mapName); @@ -1445,7 +1445,7 @@ if (auto constOp = mlir::dyn_cast( devId.getDefiningOp())) if (auto intAttr = - constOp.getValue().dyn_cast()) + dyn_cast(constOp.getValue())) deviceID = intAttr.getInt(); numMapOperands = dataOp.getMapOperands().size(); @@ -1464,7 +1464,7 @@ if (auto constOp = mlir::dyn_cast( devId.getDefiningOp())) if (auto intAttr = - constOp.getValue().dyn_cast()) + dyn_cast(constOp.getValue())) deviceID = intAttr.getInt(); numMapOperands = enterDataOp.getMapOperands().size(); @@ -1483,7 +1483,7 @@ if (auto constOp = mlir::dyn_cast( devId.getDefiningOp())) if (auto intAttr = - constOp.getValue().dyn_cast()) + dyn_cast(constOp.getValue())) deviceID = intAttr.getInt(); numMapOperands = exitDataOp.getMapOperands().size(); diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMPCommon.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMPCommon.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMPCommon.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMPCommon.cpp @@ -16,7 +16,7 @@

mlir::LLVM::createSourceLocStrFromLocation(Location loc, llvm::OpenMPIRBuilder &builder, StringRef name, uint32_t &strLen) { - if (auto fileLoc = loc.dyn_cast()) { + if (auto fileLoc = dyn_cast(loc)) { StringRef fileName = fileLoc.getFilename(); unsigned lineNo = fileLoc.getLine(); unsigned colNo = fileLoc.getColumn(); @@ -32,7 +32,7 @@ mlir::LLVM::createMappingInformation(Location loc, llvm::OpenMPIRBuilder &builder) { uint32_t strLen; - if (auto nameLoc = loc.dyn_cast()) { + if (auto nameLoc = dyn_cast(loc)) { StringRef name = nameLoc.getName(); return createSourceLocStrFromLocation(nameLoc.getChildLoc(), builder, name, strLen); diff --git a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp @@ -109,7 +109,7 @@ auto func = dyn_cast(op); if (!func) return failure(); - auto value = attribute.getValue().dyn_cast(); + auto value = dyn_cast(attribute.getValue()); if (!value) return failure(); @@ -125,7 +125,7 @@ auto func = dyn_cast(op); if (!func) return failure(); - auto value = attribute.getValue().dyn_cast(); + auto value = dyn_cast(attribute.getValue()); if (!value) return failure(); @@ -142,7 +142,7 @@ auto func = dyn_cast(op); if (!func) return failure(); - auto value = attribute.getValue().dyn_cast(); + auto value = dyn_cast(attribute.getValue()); if (!value) return failure(); llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext(); diff --git a/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp b/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp --- a/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp @@ -190,7 +190,7 @@ void LoopAnnotationConversion::convertLocation(FusedLoc location) { auto localScopeAttr = - location.getMetadata().dyn_cast_or_null(); + dyn_cast_or_null(location.getMetadata()); if (!localScopeAttr) return; auto *localScope = dyn_cast( diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -630,7 +630,7 @@ if (!type) return nullptr; - if (type.isa()) + if (isa(type)) return type; // LLVM vectors can only contain scalars. @@ -647,12 +647,12 @@ } // LLVM arrays can contain other arrays or vectors. - if (auto arrayType = type.dyn_cast()) { + if (auto arrayType = dyn_cast(type)) { // Recover the nested array shape. SmallVector shape; shape.push_back(arrayType.getNumElements()); - while (arrayType.getElementType().isa()) { - arrayType = arrayType.getElementType().cast(); + while (isa(arrayType.getElementType())) { + arrayType = cast(arrayType.getElementType()); shape.push_back(arrayType.getNumElements()); } @@ -710,12 +710,12 @@ // Convert constant data to a dense elements attribute. 
if (auto *cd = dyn_cast(value)) { Type type = convertType(cd->getElementType()); - auto attrType = getStdTypeForAttr(convertType(cd->getType())) - .dyn_cast_or_null(); + auto attrType = dyn_cast_or_null( + getStdTypeForAttr(convertType(cd->getType()))); if (!attrType) return nullptr; - if (type.isa()) { + if (isa(type)) { SmallVector values; values.reserve(cd->getNumElements()); for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i) @@ -723,7 +723,7 @@ return DenseElementsAttr::get(attrType, values); } - if (type.isa()) { + if (isa(type)) { SmallVector values; values.reserve(cd->getNumElements()); for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i) @@ -737,8 +737,8 @@ // Unpack constant aggregates to create dense elements attribute whenever // possible. Return nullptr (failure) otherwise. if (isa(value)) { - auto outerType = getStdTypeForAttr(convertType(value->getType())) - .dyn_cast_or_null(); + auto outerType = dyn_cast_or_null( + getStdTypeForAttr(convertType(value->getType()))); if (!outerType) return nullptr; @@ -746,8 +746,8 @@ SmallVector shape; for (unsigned i = 0, e = value->getNumOperands(); i < e; ++i) { - auto nested = getConstantAsAttr(value->getAggregateElement(i)) - .dyn_cast_or_null(); + auto nested = dyn_cast_or_null( + getConstantAsAttr(value->getAggregateElement(i))); if (!nested) return nullptr; @@ -921,7 +921,7 @@ // Convert constants that can be represented as attributes. if (Attribute attr = getConstantAsAttr(constant)) { Type type = convertType(constant->getType()); - if (auto symbolRef = attr.dyn_cast()) { + if (auto symbolRef = dyn_cast(attr)) { return builder.create(loc, type, symbolRef.getValue()) .getResult(); } @@ -998,7 +998,7 @@ // Generate an UndefOp as root value and insert the aggregate elements. Type rootType = convertType(constant->getType()); - bool isArrayOrStruct = rootType.isa(); + bool isArrayOrStruct = isa(rootType); assert((isArrayOrStruct || LLVM::isCompatibleVectorType(rootType)) && "unrecognized aggregate type"); Value root = builder.create(loc, rootType); @@ -1558,7 +1558,7 @@ clearBlockAndValueMapping(); auto functionType = - convertType(func->getFunctionType()).dyn_cast(); + dyn_cast(convertType(func->getFunctionType())); if (func->isIntrinsic() && iface.isConvertibleIntrinsic(func->getIntrinsicID())) return success(); diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -73,7 +73,7 @@ if (!key) continue; if (key.getValue() == DLTIDialect::kDataLayoutEndiannessKey) { - auto value = entry.getValue().cast(); + auto value = cast(entry.getValue()); bool isLittleEndian = value.getValue() == DLTIDialect::kDataLayoutEndiannessLittle; layoutStream << "-" << (isLittleEndian ? "e" : "E"); @@ -81,7 +81,7 @@ continue; } if (key.getValue() == DLTIDialect::kDataLayoutAllocaMemorySpaceKey) { - auto value = entry.getValue().cast(); + auto value = cast(entry.getValue()); uint64_t space = value.getValue().getZExtValue(); // Skip the default address space. if (space == 0) @@ -91,7 +91,7 @@ continue; } if (key.getValue() == DLTIDialect::kDataLayoutStackAlignmentKey) { - auto value = entry.getValue().cast(); + auto value = cast(entry.getValue()); uint64_t alignment = value.getValue().getZExtValue(); // Skip the default stack alignment. if (alignment == 0) @@ -112,14 +112,14 @@ if (!type) continue; // Data layout for the index type is irrelevant at this point. 
- if (type.isa()) + if (isa(type)) continue; layoutStream << "-"; LogicalResult result = llvm::TypeSwitch(type) .Case([&](Type type) -> LogicalResult { - if (auto intType = type.dyn_cast()) { + if (auto intType = dyn_cast(type)) { if (intType.getSignedness() != IntegerType::Signless) return emitError(*loc) << "unsupported data layout for non-signless integer " @@ -250,7 +250,7 @@ // Compute the shape of all dimensions but the innermost. Note that the // innermost dimension may be that of the vector element type. - bool hasVectorElementType = type.getElementType().isa(); + bool hasVectorElementType = isa(type.getElementType()); unsigned numAggregates = denseElementsAttr.getNumElements() / (hasVectorElementType ? 1 @@ -261,7 +261,7 @@ // Handle the case of vector splat, LLVM has special support for it. if (denseElementsAttr.isSplat() && - (type.isa() || hasVectorElementType)) { + (isa(type) || hasVectorElementType)) { llvm::Constant *splatValue = LLVM::detail::getLLVMConstant( innermostLLVMType, denseElementsAttr.getSplatValue(), loc, moduleTranslation); @@ -277,8 +277,8 @@ // In case of non-splat, create a constructor for the innermost constant from // a piece of raw data. std::function buildCstData; - if (type.isa()) { - auto vectorElementType = type.getElementType().dyn_cast(); + if (isa(type)) { + auto vectorElementType = dyn_cast(type.getElementType()); if (vectorElementType && vectorElementType.getRank() == 1) { buildCstData = [&](StringRef data) { return llvm::ConstantDataVector::getRaw( @@ -290,7 +290,7 @@ innermostLLVMType); }; } - } else if (type.isa()) { + } else if (isa(type)) { buildCstData = [&](StringRef data) { return llvm::ConstantDataVector::getRaw(data, type.getShape().back(), innermostLLVMType); @@ -326,7 +326,7 @@ if (!attr) return llvm::UndefValue::get(llvmType); if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) { - auto arrayAttr = attr.dyn_cast(); + auto arrayAttr = dyn_cast(attr); if (!arrayAttr || arrayAttr.size() != 2) { emitError(loc, "expected struct type to be a complex number"); return nullptr; @@ -344,11 +344,11 @@ } // For integer types, we allow a mismatch in sizes as the index type in // MLIR might have a different size than the index type in the LLVM module. - if (auto intAttr = attr.dyn_cast()) + if (auto intAttr = dyn_cast(attr)) return llvm::ConstantInt::get( llvmType, intAttr.getValue().sextOrTrunc(llvmType->getIntegerBitWidth())); - if (auto floatAttr = attr.dyn_cast()) { + if (auto floatAttr = dyn_cast(attr)) { if (llvmType != llvm::Type::getFloatingPointTy(llvmType->getContext(), floatAttr.getValue().getSemantics())) { @@ -357,10 +357,10 @@ } return llvm::ConstantFP::get(llvmType, floatAttr.getValue()); } - if (auto funcAttr = attr.dyn_cast()) + if (auto funcAttr = dyn_cast(attr)) return llvm::ConstantExpr::getBitCast( moduleTranslation.lookupFunction(funcAttr.getValue()), llvmType); - if (auto splatAttr = attr.dyn_cast()) { + if (auto splatAttr = dyn_cast(attr)) { llvm::Type *elementType; uint64_t numElements; bool isScalable = false; @@ -401,13 +401,13 @@ // Try using raw elements data if possible. if (llvm::Constant *result = - convertDenseElementsAttr(loc, attr.dyn_cast(), + convertDenseElementsAttr(loc, dyn_cast(attr), llvmType, moduleTranslation)) { return result; } // Fall back to element-by-element construction otherwise. 
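One more variant appears just below: dyn_cast_or_null, which, unlike dyn_cast, also tolerates a null handle; it pairs naturally with optional accessors such as getValueOrNull() on globals. A sketch (the helper name is hypothetical):

#include "mlir/IR/BuiltinAttributes.h"
#include "llvm/ADT/StringRef.h"

// Sketch: tolerate a missing attribute without a separate null check.
static llvm::StringRef stringValueOrEmpty(mlir::Attribute maybeNull) {
  if (auto strAttr = llvm::dyn_cast_or_null<mlir::StringAttr>(maybeNull))
    return strAttr.getValue();
  return "";
}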
- if (auto elementsAttr = attr.dyn_cast()) { + if (auto elementsAttr = dyn_cast(attr)) { assert(elementsAttr.getShapedType().hasStaticShape()); assert(!elementsAttr.getShapedType().getShape().empty() && "unexpected empty elements attribute shape"); @@ -428,7 +428,7 @@ return result; } - if (auto stringAttr = attr.dyn_cast()) { + if (auto stringAttr = dyn_cast(attr)) { return llvm::ConstantDataArray::get( moduleTranslation.getLLVMContext(), ArrayRef{stringAttr.getValue().data(), @@ -685,7 +685,7 @@ if (op.getValueOrNull()) { // String attributes are treated separately because they cannot appear as // in-function constants and are thus not supported by getLLVMConstant. - if (auto strAttr = op.getValueOrNull().dyn_cast_or_null()) { + if (auto strAttr = dyn_cast_or_null(op.getValueOrNull())) { cst = llvm::ConstantDataArray::getString( llvmModule->getContext(), strAttr.getValue(), /*AddNull=*/false); type = cst->getType(); @@ -763,11 +763,10 @@ ctorOp ? llvm::appendToGlobalCtors : llvm::appendToGlobalDtors; for (auto symbolAndPriority : range) { llvm::Function *f = lookupFunction( - std::get<0>(symbolAndPriority).cast().getValue()); - appendGlobalFn( - *llvmModule, f, - std::get<1>(symbolAndPriority).cast().getInt(), - /*Data=*/nullptr); + cast(std::get<0>(symbolAndPriority)).getValue()); + appendGlobalFn(*llvmModule, f, + cast(std::get<1>(symbolAndPriority)).getInt(), + /*Data=*/nullptr); } } @@ -830,20 +829,20 @@ return success(); for (Attribute attr : *attributes) { - if (auto stringAttr = attr.dyn_cast()) { + if (auto stringAttr = dyn_cast(attr)) { if (failed( checkedAddLLVMFnAttribute(loc, llvmFunc, stringAttr.getValue()))) return failure(); continue; } - auto arrayAttr = attr.dyn_cast(); + auto arrayAttr = dyn_cast(attr); if (!arrayAttr || arrayAttr.size() != 2) return emitError(loc) << "expected 'passthrough' to contain string or array attributes"; - auto keyAttr = arrayAttr[0].dyn_cast(); - auto valueAttr = arrayAttr[1].dyn_cast(); + auto keyAttr = dyn_cast(arrayAttr[0]); + auto valueAttr = dyn_cast(arrayAttr[1]); if (!keyAttr || !valueAttr) return emitError(loc) << "expected arrays within 'passthrough' to contain two strings"; @@ -985,7 +984,7 @@ // Convert result attributes. if (ArrayAttr allResultAttrs = function.getAllResultAttrs()) { - DictionaryAttr resultAttrs = allResultAttrs[0].cast(); + DictionaryAttr resultAttrs = cast(allResultAttrs[0]); llvmFunc->addRetAttrs(convertParameterAttrs(resultAttrs)); } @@ -1133,7 +1132,7 @@ return; } - SymbolRefAttr tagRef = tagRefs[0].cast(); + SymbolRefAttr tagRef = cast(tagRefs[0]); llvm::MDNode *node = getTBAANode(op, tagRef); inst->setMetadata(llvm::LLVMContext::MD_tbaa, node); } @@ -1192,7 +1191,7 @@ // The type references are in 1, 3, 5, etc. positions. 
unsigned opNum = 1; for (Attribute typeAttr : tdOp.getMembers()) { - refNames.push_back(typeAttr.cast().getValue()); + refNames.push_back(cast(typeAttr).getValue()); operandIndices.push_back(opNum); opNum += 2; } @@ -1299,7 +1298,7 @@ auto llvmModule = std::make_unique(name, llvmContext); if (auto dataLayoutAttr = m->getAttr(LLVM::LLVMDialect::getDataLayoutAttrName())) { - llvmModule->setDataLayout(dataLayoutAttr.cast().getValue()); + llvmModule->setDataLayout(cast(dataLayoutAttr).getValue()); } else { FailureOr llvmDataLayout(llvm::DataLayout("")); if (auto iface = dyn_cast(m)) { @@ -1319,7 +1318,7 @@ } if (auto targetTripleAttr = m->getAttr(LLVM::LLVMDialect::getTargetTripleAttrName())) - llvmModule->setTargetTriple(targetTripleAttr.cast().getValue()); + llvmModule->setTargetTriple(cast(targetTripleAttr).getValue()); // Inject declarations for `malloc` and `free` functions that can be used in // memref allocation/deallocation coming from standard ops lowering. diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -364,11 +364,11 @@ } Type fnType = getType(operands[3]); - if (!fnType || !fnType.isa()) { + if (!fnType || !isa(fnType)) { return emitError(unknownLoc, "unknown function type from ") << operands[3]; } - auto functionType = fnType.cast(); + auto functionType = cast(fnType); if ((isVoidType(resultType) && functionType.getNumResults() != 0) || (functionType.getNumResults() == 1 && @@ -562,7 +562,7 @@ return emitError(unknownLoc, "unknown result type : ") << operands[wordIndex]; } - auto ptrType = type.dyn_cast(); + auto ptrType = dyn_cast(type); if (!ptrType) { return emitError(unknownLoc, "expected a result type to be a spirv.ptr, found : ") @@ -623,7 +623,7 @@ if (!constInfo) { return nullptr; } - return constInfo->first.dyn_cast(); + return dyn_cast(constInfo->first); } LogicalResult spirv::Deserializer::processName(ArrayRef operands) { @@ -825,7 +825,7 @@ << operands[2] << "can only come from normal constant right now"; } - if (auto intVal = countInfo->first.dyn_cast()) { + if (auto intVal = dyn_cast(countInfo->first)) { count = intVal.getValue().getZExtValue(); } else { return emitError(unknownLoc, "OpTypeArray count must come from a " @@ -1172,7 +1172,7 @@ auto resultID = operands[1]; - if (auto intType = resultType.dyn_cast()) { + if (auto intType = dyn_cast(resultType)) { auto bitwidth = intType.getWidth(); if (failed(checkOperandSizeForBitwidth(bitwidth))) { return failure(); @@ -1205,7 +1205,7 @@ return success(); } - if (auto floatType = resultType.dyn_cast()) { + if (auto floatType = dyn_cast(resultType)) { auto bitwidth = floatType.getWidth(); if (failed(checkOperandSizeForBitwidth(bitwidth))) { return failure(); @@ -1295,12 +1295,12 @@ } auto resultID = operands[1]; - if (auto vectorType = resultType.dyn_cast()) { + if (auto vectorType = dyn_cast(resultType)) { auto attr = DenseElementsAttr::get(vectorType, elements); // For normal constants, we just record the attribute (and its type) for // later materialization at use sites. 
constantMap.try_emplace(resultID, attr, resultType); - } else if (auto arrayType = resultType.dyn_cast()) { + } else if (auto arrayType = dyn_cast(resultType)) { auto attr = opBuilder.getArrayAttr(elements); constantMap.try_emplace(resultID, attr, resultType); } else { @@ -1444,7 +1444,7 @@ } auto resultID = operands[1]; - if (resultType.isIntOrFloat() || resultType.isa()) { + if (resultType.isIntOrFloat() || isa(resultType)) { auto attr = opBuilder.getZeroAttr(resultType); // For normal constants, we just record the attribute (and its type) for // later materialization at use sites. diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp --- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp @@ -98,7 +98,7 @@ auto constituents = op.getConstituents(); for (auto index : llvm::seq(0, constituents.size())) { - auto constituent = constituents[index].dyn_cast(); + auto constituent = dyn_cast(constituents[index]); auto constituentName = constituent.getValue(); auto constituentID = getSpecConstID(constituentName); @@ -280,7 +280,7 @@ auto attr = op->getAttr(spirv::attributeName()); if (attr) { operands.push_back( - static_cast(attr.cast().getValue())); + static_cast(cast(attr).getValue())); } elidedAttrs.push_back(spirv::attributeName()); for (auto arg : op.getODSOperands(0)) { @@ -491,7 +491,7 @@ if (auto weights = condBranchOp.getBranchWeights()) { for (auto val : weights->getValue()) - arguments.push_back(val.cast().getInt()); + arguments.push_back(cast(val).getInt()); } if (failed(emitDebugLine(functionBody, condBranchOp.getLoc()))) @@ -554,7 +554,7 @@ // Add the interface values. if (auto interface = op.getInterface()) { for (auto var : interface.getValue()) { - auto id = getVariableID(var.cast().getValue()); + auto id = getVariableID(cast(var).getValue()); if (!id) { return op.emitError( "referencing undefined global variable." 
@@ -617,7 +617,7 @@ operands.push_back(valueID); } - if (!resultTy.isa()) + if (!isa(resultTy)) valueIDMap[op.getResult(0)] = funcCallID; encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall, operands); @@ -638,28 +638,28 @@ if (auto attr = op->getAttr("memory_access")) { operands.push_back( - static_cast(attr.cast().getValue())); + static_cast(cast(attr).getValue())); } elidedAttrs.push_back("memory_access"); if (auto attr = op->getAttr("alignment")) { operands.push_back(static_cast( - attr.cast().getValue().getZExtValue())); + cast(attr).getValue().getZExtValue())); } elidedAttrs.push_back("alignment"); if (auto attr = op->getAttr("source_memory_access")) { operands.push_back( - static_cast(attr.cast().getValue())); + static_cast(cast(attr).getValue())); } elidedAttrs.push_back("source_memory_access"); if (auto attr = op->getAttr("source_alignment")) { operands.push_back(static_cast( - attr.cast().getValue().getZExtValue())); + cast(attr).getValue().getZExtValue())); } elidedAttrs.push_back("source_alignment"); @@ -689,7 +689,7 @@ for (Value operand : op->getOperands()) operands.push_back(getValueID(operand)); spirv::StorageClass resultStorage = - resultTy.cast().getStorageClass(); + cast(resultTy).getStorageClass(); operands.push_back(static_cast(resultStorage)); encodeInstructionInto(functionBody, spirv::Opcode::OpGenericCastToPtrExplicit, operands); diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -144,7 +144,7 @@ << "id = " << valueIDPair.second << ' '; if (auto *op = val.getDefiningOp()) { os << "from op '" << op->getName() << "'"; - } else if (auto arg = val.dyn_cast()) { + } else if (auto arg = dyn_cast(val)) { Block *block = arg.getOwner(); os << "from argument of block " << block << ' '; os << " in op '" << block->getParentOp()->getName() << "'"; @@ -176,7 +176,7 @@ void Serializer::processDebugInfo() { if (!options.emitDebugInfo) return; - auto fileLoc = module.getLoc().dyn_cast(); + auto fileLoc = dyn_cast(module.getLoc()); auto fileName = fileLoc ? fileLoc.getFilename().strref() : ""; fileID = getNextID(); SmallVector operands; @@ -221,13 +221,13 @@ case spirv::Decoration::Binding: case spirv::Decoration::DescriptorSet: case spirv::Decoration::Location: - if (auto intAttr = attr.getValue().dyn_cast()) { + if (auto intAttr = dyn_cast(attr.getValue())) { args.push_back(intAttr.getValue().getZExtValue()); break; } return emitError(loc, "expected integer attribute for ") << attrName; case spirv::Decoration::BuiltIn: - if (auto strAttr = attr.getValue().dyn_cast()) { + if (auto strAttr = dyn_cast(attr.getValue())) { auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue()); if (enumVal) { args.push_back(static_cast(*enumVal)); @@ -245,7 +245,7 @@ case spirv::Decoration::Restrict: case spirv::Decoration::RelaxedPrecision: // For unit attributes, the args list has no values so we do nothing - if (auto unitAttr = attr.getValue().dyn_cast()) + if (auto unitAttr = dyn_cast(attr.getValue())) break; return emitError(loc, "expected unit attribute for ") << attrName; default: @@ -307,13 +307,13 @@ // "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and // PushConstant Storage Classes must be explicitly laid out." 
bool Serializer::isInterfaceStructPtrType(Type type) const { - if (auto ptrType = type.dyn_cast()) { + if (auto ptrType = dyn_cast(type)) { switch (ptrType.getStorageClass()) { case spirv::StorageClass::PhysicalStorageBuffer: case spirv::StorageClass::PushConstant: case spirv::StorageClass::StorageBuffer: case spirv::StorageClass::Uniform: - return ptrType.getPointeeType().isa(); + return isa(ptrType.getPointeeType()); default: break; } @@ -343,8 +343,8 @@ auto typeEnum = spirv::Opcode::OpTypeVoid; bool deferSerialization = false; - if ((type.isa() && - succeeded(prepareFunctionType(loc, type.cast(), typeEnum, + if ((isa(type) && + succeeded(prepareFunctionType(loc, cast(type), typeEnum, operands))) || succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands, deferSerialization, serializationCtx))) { @@ -390,7 +390,7 @@ return success(); } - if (auto intType = type.dyn_cast()) { + if (auto intType = dyn_cast(type)) { if (intType.getWidth() == 1) { typeEnum = spirv::Opcode::OpTypeBool; return success(); @@ -406,13 +406,13 @@ return success(); } - if (auto floatType = type.dyn_cast()) { + if (auto floatType = dyn_cast(type)) { typeEnum = spirv::Opcode::OpTypeFloat; operands.push_back(floatType.getWidth()); return success(); } - if (auto vectorType = type.dyn_cast()) { + if (auto vectorType = dyn_cast(type)) { uint32_t elementTypeID = 0; if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID, serializationCtx))) { @@ -424,7 +424,7 @@ return success(); } - if (auto imageType = type.dyn_cast()) { + if (auto imageType = dyn_cast(type)) { typeEnum = spirv::Opcode::OpTypeImage; uint32_t sampledTypeID = 0; if (failed(processType(loc, imageType.getElementType(), sampledTypeID))) @@ -440,7 +440,7 @@ return success(); } - if (auto arrayType = type.dyn_cast()) { + if (auto arrayType = dyn_cast(type)) { typeEnum = spirv::Opcode::OpTypeArray; uint32_t elementTypeID = 0; if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID, @@ -455,10 +455,10 @@ return processTypeDecoration(loc, arrayType, resultID); } - if (auto ptrType = type.dyn_cast()) { + if (auto ptrType = dyn_cast(type)) { uint32_t pointeeTypeID = 0; spirv::StructType pointeeStruct = - ptrType.getPointeeType().dyn_cast(); + dyn_cast(ptrType.getPointeeType()); if (pointeeStruct && pointeeStruct.isIdentified() && serializationCtx.count(pointeeStruct.getIdentifier()) != 0) { @@ -510,7 +510,7 @@ return success(); } - if (auto runtimeArrayType = type.dyn_cast()) { + if (auto runtimeArrayType = dyn_cast(type)) { uint32_t elementTypeID = 0; if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(), elementTypeID, serializationCtx))) { @@ -521,7 +521,7 @@ return processTypeDecoration(loc, runtimeArrayType, resultID); } - if (auto sampledImageType = type.dyn_cast()) { + if (auto sampledImageType = dyn_cast(type)) { typeEnum = spirv::Opcode::OpTypeSampledImage; uint32_t imageTypeID = 0; if (failed( @@ -532,7 +532,7 @@ return success(); } - if (auto structType = type.dyn_cast()) { + if (auto structType = dyn_cast(type)) { if (structType.isIdentified()) { if (failed(processName(resultID, structType.getIdentifier()))) return failure(); @@ -581,7 +581,7 @@ } if (auto cooperativeMatrixType = - type.dyn_cast()) { + dyn_cast(type)) { uint32_t elementTypeID = 0; if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(), elementTypeID, serializationCtx))) { @@ -600,7 +600,7 @@ return success(); } - if (auto jointMatrixType = type.dyn_cast()) { + if (auto jointMatrixType = dyn_cast(type)) { uint32_t 
elementTypeID = 0; if (failed(processTypeImpl(loc, jointMatrixType.getElementType(), elementTypeID, serializationCtx))) { @@ -621,7 +621,7 @@ return success(); } - if (auto matrixType = type.dyn_cast()) { + if (auto matrixType = dyn_cast(type)) { uint32_t elementTypeID = 0; if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID, serializationCtx))) { @@ -684,12 +684,12 @@ } uint32_t resultID = 0; - if (auto attr = valueAttr.dyn_cast()) { - int rank = attr.getType().dyn_cast().getRank(); + if (auto attr = dyn_cast(valueAttr)) { + int rank = dyn_cast(attr.getType()).getRank(); SmallVector index(rank); resultID = prepareDenseElementsConstant(loc, constType, attr, /*dim=*/0, index); - } else if (auto arrayAttr = valueAttr.dyn_cast()) { + } else if (auto arrayAttr = dyn_cast(valueAttr)) { resultID = prepareArrayConstant(loc, constType, arrayAttr); } @@ -712,7 +712,7 @@ uint32_t resultID = getNextID(); SmallVector operands = {typeID, resultID}; operands.reserve(attr.size() + 2); - auto elementType = constType.cast().getElementType(); + auto elementType = cast(constType).getElementType(); for (Attribute elementAttr : attr) { if (auto elementID = prepareConstant(loc, elementType, elementAttr)) { operands.push_back(elementID); @@ -732,16 +732,16 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType, DenseElementsAttr valueAttr, int dim, MutableArrayRef index) { - auto shapedType = valueAttr.getType().dyn_cast(); + auto shapedType = dyn_cast(valueAttr.getType()); assert(dim <= shapedType.getRank()); if (shapedType.getRank() == dim) { - if (auto attr = valueAttr.dyn_cast()) { + if (auto attr = dyn_cast(valueAttr)) { return attr.getType().getElementType().isInteger(1) ? prepareConstantBool(loc, attr.getValues()[index]) : prepareConstantInt(loc, attr.getValues()[index]); } - if (auto attr = valueAttr.dyn_cast()) { + if (auto attr = dyn_cast(valueAttr)) { return prepareConstantFp(loc, attr.getValues()[index]); } return 0; @@ -755,7 +755,7 @@ uint32_t resultID = getNextID(); SmallVector operands = {typeID, resultID}; operands.reserve(shapedType.getDimSize(dim) + 2); - auto elementType = constType.cast().getElementType(0); + auto elementType = cast(constType).getElementType(0); for (int i = 0; i < shapedType.getDimSize(dim); ++i) { index[dim] = i; if (auto elementID = prepareDenseElementsConstant( @@ -773,13 +773,13 @@ uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr, bool isSpec) { - if (auto floatAttr = valueAttr.dyn_cast()) { + if (auto floatAttr = dyn_cast(valueAttr)) { return prepareConstantFp(loc, floatAttr, isSpec); } - if (auto boolAttr = valueAttr.dyn_cast()) { + if (auto boolAttr = dyn_cast(valueAttr)) { return prepareConstantBool(loc, boolAttr, isSpec); } - if (auto intAttr = valueAttr.dyn_cast()) { + if (auto intAttr = dyn_cast(valueAttr)) { return prepareConstantInt(loc, intAttr, isSpec); } @@ -797,8 +797,7 @@ // Process the type for this bool literal uint32_t typeID = 0; - if (failed( - processType(loc, boolAttr.cast().getType(), typeID))) { + if (failed(processType(loc, cast(boolAttr).getType(), typeID))) { return 0; } @@ -1246,7 +1245,7 @@ return success(); } - auto fileLoc = loc.dyn_cast(); + auto fileLoc = dyn_cast(loc); if (fileLoc) encodeInstructionInto(binary, spirv::Opcode::OpLine, {fileID, fileLoc.getLine(), fileLoc.getColumn()}); diff --git a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp --- a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp +++ b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp @@ 
-239,7 +239,7 @@ // replacement values. bool usesReplOperation = replValues.size() == 1 && - replValues.front().getType().isa(); + isa(replValues.front().getType()); builder.create( loc, rootExpr, usesReplOperation ? replValues[0] : Value(), usesReplOperation ? ValueRange() : ValueRange(replValues)); @@ -441,7 +441,7 @@ if (ast::OperationType opType = parentType.dyn_cast()) { if (isa(expr)) { Type mlirType = genType(expr->getType()); - if (mlirType.isa()) + if (isa(mlirType)) return builder.create(loc, mlirType, parentExprs[0], builder.getI32IntegerAttr(0)); return builder.create(loc, mlirType, parentExprs[0]); diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp --- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp @@ -58,7 +58,7 @@ StringRef uriScheme, const lsp::URIForFile *uri = nullptr) { std::optional location; loc->walk([&](Location nestedLoc) { - FileLineColLoc fileLoc = nestedLoc.dyn_cast(); + FileLineColLoc fileLoc = dyn_cast(nestedLoc); if (!fileLoc) return WalkResult::advance(); @@ -91,7 +91,7 @@ const lsp::URIForFile &uri) { SetVector visitedLocs; loc->walk([&](Location nestedLoc) { - FileLineColLoc fileLoc = nestedLoc.dyn_cast(); + FileLineColLoc fileLoc = dyn_cast(nestedLoc); if (!fileLoc || !visitedLocs.insert(nestedLoc)) return WalkResult::advance(); diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -136,7 +136,7 @@ // If the existing operation has an unknown location and the current // operation doesn't, then set the existing op's location to that of the // current op. - if (existing->getLoc().isa() && !op->getLoc().isa()) + if (isa(existing->getLoc()) && !isa(op->getLoc())) existing->setLoc(op->getLoc()); ++numCSE; diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp --- a/mlir/lib/Transforms/Inliner.cpp +++ b/mlir/lib/Transforms/Inliner.cpp @@ -345,7 +345,7 @@ // TODO: Support inlining nested call references. 
CallInterfaceCallable callable = call.getCallableForCallee(); if (SymbolRefAttr symRef = dyn_cast(callable)) { - if (!symRef.isa()) + if (!isa(symRef)) continue; } diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp --- a/mlir/lib/Transforms/Mem2Reg.cpp +++ b/mlir/lib/Transforms/Mem2Reg.cpp @@ -180,7 +180,7 @@ : slot(slot), allocator(allocator), builder(builder), dominance(dominance) { #ifndef NDEBUG auto isResultOrNewBlockArgument = [&]() { - if (BlockArgument arg = slot.ptr.dyn_cast()) + if (BlockArgument arg = dyn_cast(slot.ptr)) return arg.getOwner()->getParentOp() == allocator; return slot.ptr.getDefiningOp() == allocator; }; diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -401,7 +401,7 @@ SmallVectorImpl &unresolvedMaterializations) { Block *insertBlock = input.getParentBlock(); Block::iterator insertPt = insertBlock->begin(); - if (OpResult inputRes = input.dyn_cast()) + if (OpResult inputRes = dyn_cast(input)) insertPt = ++inputRes.getOwner()->getIterator(); return buildUnresolvedMaterialization( @@ -1033,7 +1033,7 @@ if (!repl) continue; - if (repl.isa()) { + if (isa(repl)) { arg.replaceAllUsesWith(repl); continue; } @@ -1041,7 +1041,7 @@ // If the replacement value is an operation, we check to make sure that we // don't replace uses that are within the parent operation of the // replacement value. - Operation *replOp = repl.cast().getOwner(); + Operation *replOp = cast(repl).getOwner(); Block *replBlock = replOp->getBlock(); arg.replaceUsesWithIf(repl, [&](OpOperand &operand) { Operation *user = operand.getOwner(); @@ -2615,7 +2615,7 @@ } // Check to see if this is an argument materialization. - auto isBlockArg = [](Value v) { return v.isa(); }; + auto isBlockArg = [](Value v) { return isa(v); }; if (llvm::any_of(op->getOperands(), isBlockArg) || llvm::any_of(inverseMapping[op->getResult(0)], isBlockArg)) { mat->setKind(UnresolvedMaterialization::Argument); diff --git a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp --- a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp +++ b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp @@ -384,7 +384,7 @@ assert(castKind == getCastKindName(CastKind::Argument) && "unexpected value of cast kind attribute"); assert(llvm::all_of(operands, - [&](Value v) { return v.isa(); })); + [&](Value v) { return isa(v); })); maybeResult = typeConverter.materializeArgumentConversion( rewriter, castOp->getLoc(), resultTypes.front(), castOp.getOperands()); diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -244,17 +244,17 @@ bool wasProvenLive(Value value) { // TODO: For results that are removable, e.g. for region based control flow, // we could allow for these values to be tracked independently. - if (OpResult result = value.dyn_cast()) + if (OpResult result = dyn_cast(value)) return wasProvenLive(result.getOwner()); - return wasProvenLive(value.cast()); + return wasProvenLive(cast(value)); } bool wasProvenLive(BlockArgument arg) { return liveValues.count(arg); } void setProvedLive(Value value) { // TODO: For results that are removable, e.g. for region based control flow, // we could allow for these values to be tracked independently. 
-    if (OpResult result = value.dyn_cast<OpResult>())
+    if (OpResult result = dyn_cast<OpResult>(value))
       return setProvedLive(result.getOwner());
-    setProvedLive(value.cast<BlockArgument>());
+    setProvedLive(cast<BlockArgument>(value));
   }
   void setProvedLive(BlockArgument arg) {
     changed |= liveValues.insert(arg).second;
@@ -538,11 +538,11 @@
   assert(value.getParentBlock() == block && "expected value of this block");

   // Arguments use the argument number as the order index.
-  if (BlockArgument arg = value.dyn_cast<BlockArgument>())
+  if (BlockArgument arg = dyn_cast<BlockArgument>(value))
     return arg.getArgNumber();

   // Otherwise, the result order is offset from the parent op's order.
-  OpResult result = value.cast<OpResult>();
+  OpResult result = cast<OpResult>(value);
   auto opOrderIt = opOrderIndex.find(result.getDefiningOp());
   assert(opOrderIt != opOrderIndex.end() && "expected op to have an order");
   return opOrderIt->second + result.getResultNumber();
diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -145,13 +145,13 @@
     int64_t largeAttrLimit = getLargeAttributeSizeLimit();

     // Always emit splat attributes.
-    if (attr.isa<SplatElementsAttr>()) {
+    if (isa<SplatElementsAttr>(attr)) {
       attr.print(os);
       return;
     }

     // Elide "big" elements attributes.
-    auto elements = attr.dyn_cast<ElementsAttr>();
+    auto elements = dyn_cast<ElementsAttr>(attr);
     if (elements && elements.getNumElements() > largeAttrLimit) {
       os << std::string(elements.getShapedType().getRank(), '[') << "..."
          << std::string(elements.getShapedType().getRank(), ']') << " : "
@@ -159,7 +159,7 @@
       return;
     }

-    auto array = attr.dyn_cast<ArrayAttr>();
+    auto array = dyn_cast<ArrayAttr>(attr);
     if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) {
       os << "[...]";
       return;
diff --git a/mlir/test/lib/Analysis/TestAliasAnalysis.cpp b/mlir/test/lib/Analysis/TestAliasAnalysis.cpp
--- a/mlir/test/lib/Analysis/TestAliasAnalysis.cpp
+++ b/mlir/test/lib/Analysis/TestAliasAnalysis.cpp
@@ -24,7 +24,7 @@
   llvm::errs() << op->getAttrOfType<StringAttr>("test.ptr").getValue();
 }
 static void printAliasOperand(Value value) {
-  if (BlockArgument arg = value.dyn_cast<BlockArgument>()) {
+  if (BlockArgument arg = dyn_cast<BlockArgument>(value)) {
     Region *region = arg.getParentRegion();
     unsigned parentBlockNumber =
         std::distance(region->begin(), arg.getOwner()->getIterator());
@@ -37,7 +37,7 @@
     llvm::errs() << "#" << arg.getArgNumber();
     return;
   }
-  OpResult result = value.cast<OpResult>();
+  OpResult result = cast<OpResult>(value);
   printAliasOperand(result.getOwner());
   llvm::errs() << "#" << result.getResultNumber();
 }
@@ -156,7 +156,7 @@

 /// Check if value is function argument.
 static bool isFuncArg(Value val) {
-  auto blockArg = val.dyn_cast<BlockArgument>();
+  auto blockArg = dyn_cast<BlockArgument>(val);
   if (!blockArg)
     return false;

@@ -166,7 +166,7 @@
 /// Check if value has "restrict" attribute. Value must be a function argument.
 static bool isRestrict(Value val) {
-  auto blockArg = val.cast<BlockArgument>();
+  auto blockArg = cast<BlockArgument>(val);
   auto func =
       mlir::cast<FunctionOpInterface>(blockArg.getOwner()->getParentOp());
   return !!func.getArgAttr(blockArg.getArgNumber(),
diff --git a/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp b/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp
--- a/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp
+++ b/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp
@@ -32,7 +32,7 @@
 void TestMemRefStrideCalculation::runOnOperation() {
   llvm::outs() << "Testing: " << getOperation().getName() << "\n";
   getOperation().walk([&](memref::AllocOp allocOp) {
-    auto memrefType = allocOp.getResult().getType().cast<MemRefType>();
+    auto memrefType = cast<MemRefType>(allocOp.getResult().getType());
     int64_t offset;
     SmallVector<int64_t, 4> strides;
     if (failed(getStridesAndOffset(memrefType, strides, offset))) {
diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
--- a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
+++ b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
@@ -102,7 +102,7 @@
   matchAndRewrite(::test::GetTupleElementOp op, OpAdaptor adaptor,
                   OneToNPatternRewriter &rewriter) const override {
     // Construct mapping for tuple element types.
-    auto stateType = op->getOperand(0).getType().cast<TupleType>();
+    auto stateType = cast<TupleType>(op->getOperand(0).getType());
     TypeRange originalElementTypes = stateType.getTypes();
     OneToNTypeMapping elementMapping(originalElementTypes);
     if (failed(typeConverter->convertSignatureArgs(originalElementTypes,
@@ -148,7 +148,7 @@
 static std::optional<SmallVector<Value>>
 buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input,
                         Location loc) {
-  TupleType inputType = input.getType().dyn_cast<TupleType>();
+  TupleType inputType = dyn_cast<TupleType>(input.getType());
   if (!inputType)
     return {};

@@ -156,7 +156,7 @@
   for (auto [idx, elementType] : llvm::enumerate(inputType.getTypes())) {
     Value element = builder.create<::test::GetTupleElementOp>(
         loc, elementType, input, builder.getI32IntegerAttr(idx));
-    if (auto nestedTupleType = elementType.dyn_cast<TupleType>()) {
+    if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) {
       // Recurse if the current element is also a tuple.
       SmallVector<Type> flatRecursiveTypes;
       nestedTupleType.getFlattenedTypes(flatRecursiveTypes);
@@ -186,7 +186,7 @@
   elements.reserve(resultType.getTypes().size());
   ValueRange::iterator inputIt = inputs.begin();
   for (Type elementType : resultType.getTypes()) {
-    if (auto nestedTupleType = elementType.dyn_cast<TupleType>()) {
+    if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) {
       // Determine how many input values are needed for the nested elements of
       // the nested TupleType and advance inputIt by that number.
       // TODO: We only need the *number* of nested types, not the types itself.
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -81,7 +81,7 @@
         return WalkResult::skip();
       }
       Value value = op->getOperand(0);
-      if (value.getType().isa<IndexType>() !=
+      if (isa<IndexType>(value.getType()) !=
           !op->hasAttrOfType<IntegerAttr>("dim")) {
         // Op should have "dim" attribute if and only if the operand is an
         // index-typed value.
@@ -119,7 +119,7 @@
       if (reifyToFuncArgs) {
         // Reify in terms of function block arguments.
         stopCondition = stopCondition = [](Value v, std::optional<int64_t> d) {
-          auto bbArg = v.dyn_cast<BlockArgument>();
+          auto bbArg = dyn_cast<BlockArgument>(v);
           if (!bbArg)
             return false;
           return isa<FunctionOpInterface>(
@@ -166,7 +166,7 @@
         return WalkResult::skip();
       }
       Value constOp = rewriter.create<arith::ConstantIndexOp>(
-          op->getLoc(), reified->get().cast<IntegerAttr>().getInt());
+          op->getLoc(), cast<IntegerAttr>(reified->get()).getInt());
       rewriter.replaceOp(op, constOp);
       return WalkResult::skip();
     }
diff --git a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
--- a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
@@ -127,7 +127,7 @@
   // As a consequence we write only Ops with a single return type for the
   // purpose of this test. If we need to test more intricate behavior in the
   // future we can always extend.
-  auto superVectorType = opInst->getResult(0).getType().cast<VectorType>();
+  auto superVectorType = cast<VectorType>(opInst->getResult(0).getType());
   auto ratio =
       computeShapeRatio(superVectorType.getShape(), subVectorType.getShape());
   if (!ratio) {
@@ -211,8 +211,8 @@
   maps.reserve(matches.size());
   for (auto m : llvm::reverse(matches)) {
     auto *opInst = m.getMatchedOperation();
-    auto map = opInst->getAttr(VectorizerTestPass::kTestAffineMapAttrName)
-                   .cast<AffineMapAttr>()
+    auto map = cast<AffineMapAttr>(
+                   opInst->getAttr(VectorizerTestPass::kTestAffineMapAttrName))
                    .getValue();
     maps.push_back(map);
   }
diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
--- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
+++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
@@ -27,7 +27,7 @@
     Type elementType = resultType.getType(i);
     Value element = builder.create<test::GetTupleElementOp>(
         loc, elementType, value, builder.getI32IntegerAttr(i));
-    if (auto nestedTupleType = elementType.dyn_cast<TupleType>()) {
+    if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) {
       // Recurse if the current element is also a tuple.
       if (failed(buildDecomposeTuple(builder, loc, nestedTupleType, element,
                                      values)))
@@ -50,7 +50,7 @@
   elements.reserve(resultType.getTypes().size());
   ValueRange::iterator inputIt = inputs.begin();
   for (Type elementType : resultType.getTypes()) {
-    if (auto nestedTupleType = elementType.dyn_cast<TupleType>()) {
+    if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) {
       // Determine how many input values are needed for the nested elements of
       // the nested TupleType and advance inputIt by that number.
       // TODO: We only need the *number* of nested types, not the types itself.
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
@@ -38,9 +38,9 @@
   bool changed = false;
   for (LinalgOp linalgOp : llvm::reverse(linalgOps)) {
     for (OpOperand &opOperand : linalgOp->getOpOperands()) {
-      if (opOperand.get().getType().isa<MemRefType>())
+      if (isa<MemRefType>(opOperand.get().getType()))
         continue;
-      if (opOperand.get().getType().isa<RankedTensorType>()) {
+      if (isa<RankedTensorType>(opOperand.get().getType())) {
         // Tile and Fuse tensor input.
         if (opOperand.getOperandNumber() >= linalgOp.getNumDpsInputs())
           continue;
diff --git a/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp b/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp
--- a/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp
+++ b/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp
@@ -61,9 +61,9 @@
       if (attr) {
         auto lookup = [&](Attribute attr) {
           return cast<shape::FunctionLibraryOp>(
-              SymbolTable::lookupSymbolIn(module, attr.cast<SymbolRefAttr>()));
+              SymbolTable::lookupSymbolIn(module, cast<SymbolRefAttr>(attr)));
         };
-        if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
+        if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
           libraries.reserve(arrayAttr.size());
           for (auto attr : arrayAttr)
             libraries.push_back(lookup(attr));
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -113,7 +113,7 @@
     if (!op.getSource().hasOneUse())
       return false;

-    auto resultType = op.getResult().getType().cast<RankedTensorType>();
+    auto resultType = cast<RankedTensorType>(op.getResult().getType());
     constexpr int64_t kConstantFoldingMaxNumElements = 1024;
     return resultType.getNumElements() <= kConstantFoldingMaxNumElements;
   };
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -49,7 +49,7 @@
 }
 LogicalResult MyPropStruct::setFromAttr(MyPropStruct &prop, Attribute attr,
                                         InFlightDiagnostic *diag) {
-  StringAttr strAttr = attr.dyn_cast<StringAttr>();
+  StringAttr strAttr = dyn_cast<StringAttr>(attr);
   if (!strAttr) {
     if (diag)
       *diag << "Expect StringAttr but got " << attr;
@@ -222,7 +222,7 @@
   //===------------------------------------------------------------------===//

   AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
-    StringAttr strAttr = attr.dyn_cast<StringAttr>();
+    StringAttr strAttr = dyn_cast<StringAttr>(attr);
     if (!strAttr)
       return AliasResult::NoAlias;

@@ -247,16 +247,16 @@
   }

   AliasResult getAlias(Type type, raw_ostream &os) const final {
-    if (auto tupleType = type.dyn_cast<TupleType>()) {
+    if (auto tupleType = dyn_cast<TupleType>(type)) {
       if (tupleType.size() > 0 &&
           llvm::all_of(tupleType.getTypes(), [](Type elemType) {
-            return elemType.isa<SimpleAType>();
+            return isa<SimpleAType>(elemType);
           })) {
         os << "test_tuple";
         return AliasResult::FinalAlias;
       }
     }
-    if (auto intType = type.dyn_cast<TestIntegerType>()) {
+    if (auto intType = dyn_cast<TestIntegerType>(type)) {
       if (intType.getSignedness() ==
               TestIntegerType::SignednessSemantics::Unsigned &&
           intType.getWidth() == 8) {
@@ -264,7 +264,7 @@
         return AliasResult::FinalAlias;
       }
     }
-    if (auto recType = type.dyn_cast<TestRecursiveType>()) {
+    if (auto recType = dyn_cast<TestRecursiveType>(type)) {
       if (recType.getName() == "type_to_alias") {
         // We only make alias for a specific recursive type.
         os << "testrec";
@@ -1231,7 +1231,7 @@
   auto args = getRegion().front().getArguments();
   auto e = std::min(arrayAttr.size(), args.size());
   for (unsigned i = 0; i < e; ++i) {
-    if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
+    if (auto strAttr = dyn_cast<StringAttr>(arrayAttr[i]))
       setNameFn(args[i], strAttr.getValue());
   }
 }
@@ -1253,7 +1253,7 @@
 }

 static void printOptionalLoc(OpAsmPrinter &p, Operation *op, Attribute loc) {
-  p.printOptionalLocationSpecifier(loc.cast<LocationAttr>());
+  p.printOptionalLocationSpecifier(cast<LocationAttr>(loc));
 }

 //===----------------------------------------------------------------------===//
@@ -1377,7 +1377,7 @@
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   // Create return type consisting of the last element of the first operand.
   auto operandType = operands.front().getType();
-  auto sval = operandType.dyn_cast<ShapedType>();
+  auto sval = dyn_cast<ShapedType>(operandType);
   if (!sval) {
     return emitOptionalError(location, "only shaped type operands allowed");
   }
@@ -1385,7 +1385,7 @@
   auto type = IntegerType::get(context, 17);

   Attribute encoding;
-  if (auto rankedTy = sval.dyn_cast<RankedTensorType>())
+  if (auto rankedTy = dyn_cast<RankedTensorType>(sval))
     encoding = rankedTy.getEncoding();
   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
   return success();
@@ -1405,7 +1405,7 @@
   Location loc = getLoc();
   shapes.reserve(operands.size());
   for (Value operand : llvm::reverse(operands)) {
-    auto rank = operand.getType().cast<RankedTensorType>().getRank();
+    auto rank = cast<RankedTensorType>(operand.getType()).getRank();
     auto currShape = llvm::to_vector<4>(
         llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value {
           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
@@ -1422,7 +1422,7 @@
   Location loc = getLoc();
   shapes.reserve(getNumOperands());
   for (Value operand : llvm::reverse(getOperands())) {
-    auto tensorType = operand.getType().cast<RankedTensorType>();
+    auto tensorType = cast<RankedTensorType>(operand.getType());
     auto currShape = llvm::to_vector<4>(llvm::map_range(
         llvm::seq<int64_t>(0, tensorType.getRank()),
         [&](int64_t dim) -> OpFoldResult {
@@ -1472,12 +1472,12 @@
   // If there is one, it is an array of dictionary attributes that hold
   // information on the effects of this operation.
   for (Attribute element : effectsAttr) {
-    DictionaryAttr effectElement = element.cast<DictionaryAttr>();
+    DictionaryAttr effectElement = cast<DictionaryAttr>(element);
     // Get the specific memory effect.
     MemoryEffects::Effect *effect =
         StringSwitch<MemoryEffects::Effect *>(
-            effectElement.get("effect").cast<StringAttr>().getValue())
+            cast<StringAttr>(effectElement.get("effect")).getValue())
             .Case("allocate", MemoryEffects::Allocate::get())
             .Case("free", MemoryEffects::Free::get())
             .Case("read", MemoryEffects::Read::get())
@@ -1492,7 +1492,7 @@
     if (effectElement.get("on_result"))
       effects.emplace_back(effect, getResult(), resource);
     else if (Attribute ref = effectElement.get("on_reference"))
-      effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
+      effects.emplace_back(effect, cast<SymbolRefAttr>(ref), resource);
     else
       effects.emplace_back(effect, resource);
   }
@@ -1557,7 +1557,7 @@
     llvm::raw_svector_ostream tmpStream(resultNameStr);
     p.printOperand(getResult(i), tmpStream);

-    auto expectedName = getNames()[i].dyn_cast<StringAttr>();
+    auto expectedName = dyn_cast<StringAttr>(getNames()[i]);
     if (!expectedName ||
         tmpStream.str().drop_front() != expectedName.getValue()) {
       namesDisagree = true;
@@ -1577,7 +1577,7 @@
   auto value = getNames();
   for (size_t i = 0, e = value.size(); i != e; ++i)
-    if (auto str = value[i].dyn_cast<StringAttr>())
+    if (auto str = dyn_cast<StringAttr>(value[i]))
       if (!str.getValue().empty())
         setNameFn(getResult(i), str.getValue());
 }
@@ -1586,7 +1586,7 @@
                   function_ref<void(Value, StringRef)> setNameFn) {
   ArrayAttr value = getNames();
   for (size_t i = 0, e = value.size(); i != e; ++i)
-    if (auto str = value[i].dyn_cast<StringAttr>())
+    if (auto str = dyn_cast<StringAttr>(value[i]))
       if (!str.getValue().empty())
         setNameFn(getResult(i), str.getValue());
 }
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -153,7 +153,7 @@
   LogicalResult matchAndRewrite(AnyAttrOfOp op,
                                 PatternRewriter &rewriter) const override {
-    auto intAttr = op.getAttr().dyn_cast<IntegerAttr>();
+    auto intAttr = dyn_cast<IntegerAttr>(op.getAttr());
     if (!intAttr)
       return failure();
     int64_t val = intAttr.getInt();
@@ -1271,11 +1271,11 @@
     Type convertedType = getTypeConverter()
                              ? getTypeConverter()->convertType(resultType)
                              : resultType;
-    if (resultType.isa<FloatType>())
+    if (isa<FloatType>(resultType))
       resultType = rewriter.getF64Type();
     else if (resultType.isInteger(16))
       resultType = rewriter.getIntegerType(64);
-    else if (resultType.isa<test::TestRecursiveType>() &&
+    else if (isa<test::TestRecursiveType>(resultType) &&
              convertedType != resultType)
       resultType = convertedType;
     else
@@ -1430,8 +1430,8 @@
         inputs.empty())
       return builder.create<TestTypeProducerOp>(loc, resultType);
     // Allow producing an i64 from an integer.
-    if (resultType.isa<IntegerType>() && inputs.size() == 1 &&
-        inputs[0].getType().isa<IntegerType>())
+    if (isa<IntegerType>(resultType) && inputs.size() == 1 &&
+        isa<IntegerType>(inputs[0].getType()))
       return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
     // Otherwise, fail.
     return nullptr;
@@ -1440,7 +1440,7 @@
   // Initialize the conversion target.
   mlir::ConversionTarget target(getContext());
   target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
-    auto recursiveType = op.getType().dyn_cast<test::TestRecursiveType>();
+    auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType());
     return op.getType().isF64() || op.getType().isInteger(64) ||
            (recursiveType &&
            recursiveType.getName() == "outer_converted_type");
diff --git a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
--- a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
+++ b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
@@ -42,20 +42,20 @@
   auto tosaNegateOp = cast<tosa::NegateOp>(op);

   auto inputType =
-      tosaNegateOp.getInput1().getType().dyn_cast<mlir::RankedTensorType>();
+      dyn_cast<mlir::RankedTensorType>(tosaNegateOp.getInput1().getType());
   // skip if input is not ranked tensor type
   if (!inputType)
     return failure();

   // skip if it's not ranked tensor type.
   auto outputType =
-      tosaNegateOp.getResult().getType().dyn_cast<mlir::RankedTensorType>();
+      dyn_cast<mlir::RankedTensorType>(tosaNegateOp.getResult().getType());
   if (!outputType)
     return failure();

   // skip if output is not per-tensor quantized type.
   auto outputElementType =
-      outputType.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
+      dyn_cast<mlir::quant::UniformQuantizedType>(outputType.getElementType());
   if (!outputElementType)
     return failure();
@@ -112,14 +112,14 @@
   auto tosaConv2DOp = cast<tosa::Conv2DOp>(op);

   auto inputType =
-      tosaConv2DOp.getInput().getType().dyn_cast<mlir::RankedTensorType>();
+      dyn_cast<mlir::RankedTensorType>(tosaConv2DOp.getInput().getType());

   // skip if input is not ranked tensor type
   if (!inputType)
     return failure();

   auto weightType =
-      tosaConv2DOp.getWeight().getType().dyn_cast<mlir::RankedTensorType>();
+      dyn_cast<mlir::RankedTensorType>(tosaConv2DOp.getWeight().getType());

   // skip if wt is not ranked tensor type
   if (!weightType)
@@ -127,16 +127,16 @@

   // skip if it's not ranked tensor type.
   auto outputType =
-      tosaConv2DOp.getResult().getType().dyn_cast<mlir::RankedTensorType>();
+      dyn_cast<mlir::RankedTensorType>(tosaConv2DOp.getResult().getType());
   if (!outputType)
     return failure();

   auto inputQType =
-      inputType.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
+      dyn_cast<mlir::quant::UniformQuantizedType>(inputType.getElementType());
   auto weightQType =
-      weightType.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
+      dyn_cast<mlir::quant::UniformQuantizedType>(weightType.getElementType());
   auto outputQType =
-      outputType.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
+      dyn_cast<mlir::quant::UniformQuantizedType>(outputType.getElementType());

   // Works on quantized type only.
   if (!(inputQType && weightQType && outputQType))
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -89,7 +89,7 @@
       auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
       if (!extract)
         return std::nullopt;
-      auto vecType = extract.getResult().getType().cast<VectorType>();
+      auto vecType = cast<VectorType>(extract.getResult().getType());
       if (dstVec && dstVec != vecType)
         return std::nullopt;
       dstVec = vecType;
@@ -430,7 +430,7 @@
     static constexpr int64_t kSharedMemorySpace = 3;
     // Compute type of shared memory buffer.
     MemRefType memrefType;
-    if (auto vectorType = type.dyn_cast<VectorType>()) {
+    if (auto vectorType = dyn_cast<VectorType>(type)) {
       memrefType =
           MemRefType::get(vectorType.getShape(), vectorType.getElementType(),
                           {}, kSharedMemorySpace);
@@ -535,7 +535,7 @@
           // Create a map (d0, d1) -> (d1) to distribute along the inner
           // dimension. Once we support n-d distribution we can add more
           // complex cases.
-          VectorType vecType = val.getType().dyn_cast<VectorType>();
+          VectorType vecType = dyn_cast<VectorType>(val.getType());
           int64_t vecRank = vecType ? vecType.getRank() : 0;
           OpBuilder builder(val.getContext());
           if (vecRank == 0)
@@ -642,9 +642,9 @@
       if (op->getName().getStringRef() != "test_create_broadcast")
         return;
       auto targetShape =
-          op->getResult(0).getType().cast<VectorType>().getShape();
+          cast<VectorType>(op->getResult(0).getType()).getShape();
       auto arrayAttr =
-          op->getAttr("broadcast_dims").cast<DenseI64ArrayAttr>().asArrayRef();
+          cast<DenseI64ArrayAttr>(op->getAttr("broadcast_dims")).asArrayRef();
       llvm::SetVector<int64_t> broadcastedDims;
       broadcastedDims.insert(arrayAttr.begin(), arrayAttr.end());
       OpBuilder b(op);
diff --git a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
--- a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
+++ b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
@@ -34,7 +34,7 @@
   void runOnOperation() override {
     getOperation().walk([&](Operation *op) {
       for (NamedAttribute attr : op->getAttrs()) {
-        auto elementsAttr = attr.getValue().dyn_cast<ElementsAttr>();
+        auto elementsAttr = dyn_cast<ElementsAttr>(attr.getValue());
         if (!elementsAttr)
           continue;
         testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t");
diff --git a/mlir/test/lib/IR/TestDiagnostics.cpp b/mlir/test/lib/IR/TestDiagnostics.cpp
--- a/mlir/test/lib/IR/TestDiagnostics.cpp
+++ b/mlir/test/lib/IR/TestDiagnostics.cpp
@@ -36,7 +36,7 @@
   // Build a diagnostic handler that has filtering capabilities.
   auto filterFn = [&](Location loc) {
     // Ignore non-file locations.
-    FileLineColLoc fileLoc = loc.dyn_cast<FileLineColLoc>();
+    FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(loc);
     if (!fileLoc)
       return true;
diff --git a/mlir/test/lib/IR/TestFunc.cpp b/mlir/test/lib/IR/TestFunc.cpp
--- a/mlir/test/lib/IR/TestFunc.cpp
+++ b/mlir/test/lib/IR/TestFunc.cpp
@@ -35,13 +35,13 @@
     SmallVector<Location> locsToInsert;
     for (auto insert : inserts.getAsRange<ArrayAttr>()) {
       indicesToInsert.push_back(
-          insert[0].cast<IntegerAttr>().getValue().getZExtValue());
-      typesToInsert.push_back(insert[1].cast<TypeAttr>().getValue());
+          cast<IntegerAttr>(insert[0]).getValue().getZExtValue());
+      typesToInsert.push_back(cast<TypeAttr>(insert[1]).getValue());
       attrsToInsert.push_back(insert.size() > 2
-                                  ? insert[2].cast<DictionaryAttr>()
+                                  ? cast<DictionaryAttr>(insert[2])
                                   : DictionaryAttr::get(&getContext()));
       locsToInsert.push_back(insert.size() > 3
-                                 ? Location(insert[3].cast<LocationAttr>())
+                                 ? Location(cast<LocationAttr>(insert[3]))
                                  : unknownLoc);
     }
     func->removeAttr("test.insert_args");
@@ -72,10 +72,10 @@
     SmallVector<DictionaryAttr> attrsToInsert;
     for (auto insert : inserts.getAsRange<ArrayAttr>()) {
       indicesToInsert.push_back(
-          insert[0].cast<IntegerAttr>().getValue().getZExtValue());
-      typesToInsert.push_back(insert[1].cast<TypeAttr>().getValue());
+          cast<IntegerAttr>(insert[0]).getValue().getZExtValue());
+      typesToInsert.push_back(cast<TypeAttr>(insert[1]).getValue());
       attrsToInsert.push_back(insert.size() > 2
-                                  ? insert[2].cast<DictionaryAttr>()
+                                  ? cast<DictionaryAttr>(insert[2])
                                   : DictionaryAttr::get(&getContext()));
     }
     func->removeAttr("test.insert_results");
diff --git a/mlir/test/lib/IR/TestInterfaces.cpp b/mlir/test/lib/IR/TestInterfaces.cpp
--- a/mlir/test/lib/IR/TestInterfaces.cpp
+++ b/mlir/test/lib/IR/TestInterfaces.cpp
@@ -27,7 +27,7 @@
   void runOnOperation() override {
     getOperation().walk([](Operation *op) {
       for (Type type : op->getResultTypes()) {
-        if (auto testInterface = type.dyn_cast<TestTypeInterface>()) {
+        if (auto testInterface = dyn_cast<TestTypeInterface>(type)) {
           testInterface.printTypeA(op->getLoc());
           testInterface.printTypeB(op->getLoc());
           testInterface.printTypeC(op->getLoc());
@@ -37,7 +37,7 @@
           TestTypeInterface result = testInterface.printTypeRet(op->getLoc());
           (void)result;
         }
-        if (auto testType = type.dyn_cast<TestType>())
+        if (auto testType = dyn_cast<TestType>(type))
           testType.printTypeE(op->getLoc());
       }
     });
diff --git a/mlir/test/lib/IR/TestOpaqueLoc.cpp b/mlir/test/lib/IR/TestOpaqueLoc.cpp
--- a/mlir/test/lib/IR/TestOpaqueLoc.cpp
+++ b/mlir/test/lib/IR/TestOpaqueLoc.cpp
@@ -74,7 +74,7 @@
     ScopedDiagnosticHandler diagHandler(&getContext(), [](Diagnostic &diag) {
       auto &os = llvm::outs();
-      if (diag.getLocation().isa<OpaqueLoc>()) {
+      if (isa<OpaqueLoc>(diag.getLocation())) {
         MyLocation *loc = OpaqueLoc::getUnderlyingLocationOrNull<MyLocation *>(
             diag.getLocation());
         if (loc)
diff --git a/mlir/test/lib/IR/TestPrintDefUse.cpp b/mlir/test/lib/IR/TestPrintDefUse.cpp
--- a/mlir/test/lib/IR/TestPrintDefUse.cpp
+++ b/mlir/test/lib/IR/TestPrintDefUse.cpp
@@ -34,7 +34,7 @@
       } else {
         // If there is no defining op, the Value is necessarily a Block
         // argument.
-        auto blockArg = operand.cast<BlockArgument>();
+        auto blockArg = cast<BlockArgument>(operand);
         llvm::outs() << "  - Operand produced by Block argument, number "
                      << blockArg.getArgNumber() << "\n";
       }
diff --git a/mlir/test/lib/Transforms/TestTopologicalSort.cpp b/mlir/test/lib/Transforms/TestTopologicalSort.cpp
--- a/mlir/test/lib/Transforms/TestTopologicalSort.cpp
+++ b/mlir/test/lib/Transforms/TestTopologicalSort.cpp
@@ -42,7 +42,7 @@
         // If the root has an "ordered" attribute, we fill the selectedOps
         // vector in a certain order.
         int64_t pos =
-            selected->getAttr("selected").cast<IntegerAttr>().getInt();
+            cast<IntegerAttr>(selected->getAttr("selected")).getInt();
         if (pos >= static_cast<int64_t>(selectedOps.size()))
           selectedOps.append(pos + 1 - selectedOps.size(), nullptr);
         selectedOps[pos] = selected;
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -317,10 +317,10 @@
                          SerializedAffineMap &value) {
     assert(rawYamlContext);
     auto *yamlContext = static_cast<LinalgYAMLContext *>(rawYamlContext);
-    if (auto attr = mlir::parseAttribute(scalar, yamlContext->mlirContext)
-                        .dyn_cast_or_null<AffineMapAttr>())
+    if (auto attr = dyn_cast_or_null<AffineMapAttr>(
+            mlir::parseAttribute(scalar, yamlContext->mlirContext)))
       value.affineMapAttr = attr;
-    else if (!value.affineMapAttr || !value.affineMapAttr.isa<AffineMapAttr>())
+    else if (!value.affineMapAttr || !isa<AffineMapAttr>(value.affineMapAttr))
       return "could not parse as an affine map attribute";
     return StringRef();
   }
diff --git a/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp b/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp
--- a/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp
+++ b/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp
@@ -36,18 +36,18 @@
   ASSERT_EQ(subElementTypes.size(), 4U);

   // !llvm.ptr<struct<"bar",...>>
-  ASSERT_TRUE(subElementTypes[0].isa<LLVM::LLVMPointerType>());
+  ASSERT_TRUE(isa<LLVM::LLVMPointerType>(subElementTypes[0]));

   // !llvm.struct<"bar",...>
-  auto structType = subElementTypes[1].dyn_cast<LLVM::LLVMStructType>();
+  auto structType = dyn_cast<LLVM::LLVMStructType>(subElementTypes[1]);
   ASSERT_TRUE(bool(structType));
   ASSERT_TRUE(structType.getName().equals("bar"));

   // !llvm.ptr<struct<"foo",...>>
-  ASSERT_TRUE(subElementTypes[2].isa<LLVM::LLVMPointerType>());
+  ASSERT_TRUE(isa<LLVM::LLVMPointerType>(subElementTypes[2]));

   // !llvm.struct<"foo",...>
-  structType = subElementTypes[3].dyn_cast<LLVM::LLVMStructType>();
+  structType = dyn_cast<LLVM::LLVMStructType>(subElementTypes[3]);
   ASSERT_TRUE(bool(structType));
   ASSERT_TRUE(structType.getName().equals("foo"));
 }
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -278,7 +278,7 @@
   // Check that we cast to this attribute when possible.
   Attribute genericAttr = attr;
-  EXPECT_TRUE(genericAttr.template isa<AttrT>());
+  EXPECT_TRUE(isa<AttrT>(genericAttr));
 }
 template <typename AttrT>
 static void checkNativeIntAccess(Builder &builder, size_t intWidth) {
@@ -330,9 +330,9 @@
   Attribute i32ResourceAttr = DenseI32ResourceElementsAttr::get(
       type, "resource", UnmanagedAsmResourceBlob::allocateInferAlign(data));
-  EXPECT_TRUE(i32ResourceAttr.isa<DenseI32ResourceElementsAttr>());
-  EXPECT_FALSE(i32ResourceAttr.isa<DenseF32ResourceElementsAttr>());
-  EXPECT_FALSE(i32ResourceAttr.isa<DenseBoolResourceElementsAttr>());
+  EXPECT_TRUE(isa<DenseI32ResourceElementsAttr>(i32ResourceAttr));
+  EXPECT_FALSE(isa<DenseF32ResourceElementsAttr>(i32ResourceAttr));
+  EXPECT_FALSE(isa<DenseBoolResourceElementsAttr>(i32ResourceAttr));
 }

 TEST(DenseResourceElementsAttrTest, CheckInvalidData) {
@@ -407,17 +407,17 @@
   // Only index (0, 0) contains an element, others are supposed to return
   // the zero/empty value.
   auto zeroIntValue =
-      sparseInt.getValues<Attribute>()[{1, 1}].cast<IntegerAttr>();
+      cast<IntegerAttr>(sparseInt.getValues<Attribute>()[{1, 1}]);
   EXPECT_EQ(zeroIntValue.getInt(), 0);
   EXPECT_TRUE(zeroIntValue.getType() == intTy);

   auto zeroFloatValue =
-      sparseFloat.getValues<Attribute>()[{1, 1}].cast<FloatAttr>();
+      cast<FloatAttr>(sparseFloat.getValues<Attribute>()[{1, 1}]);
   EXPECT_EQ(zeroFloatValue.getValueAsDouble(), 0.0f);
   EXPECT_TRUE(zeroFloatValue.getType() == floatTy);

   auto zeroStringValue =
-      sparseString.getValues<Attribute>()[{1, 1}].cast<StringAttr>();
+      cast<StringAttr>(sparseString.getValues<Attribute>()[{1, 1}]);
   EXPECT_TRUE(zeroStringValue.getValue().empty());
   EXPECT_TRUE(zeroStringValue.getType() == stringTy);
 }
diff --git a/mlir/unittests/IR/InterfaceAttachmentTest.cpp b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
--- a/mlir/unittests/IR/InterfaceAttachmentTest.cpp
+++ b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
@@ -61,11 +61,11 @@
   // Check that the type has no interface.
   IntegerType i8 = IntegerType::get(&context, 8);
-  ASSERT_FALSE(i8.isa<TestExternalTypeInterface>());
+  ASSERT_FALSE(isa<TestExternalTypeInterface>(i8));

   // Attach an interface and check that the type now has the interface.
   IntegerType::attachInterface<Model>(context);
-  TestExternalTypeInterface iface = i8.dyn_cast<TestExternalTypeInterface>();
+  TestExternalTypeInterface iface = dyn_cast<TestExternalTypeInterface>(i8);
   ASSERT_TRUE(iface != nullptr);
   EXPECT_EQ(iface.getBitwidthPlusArg(10), 18u);
   EXPECT_EQ(iface.staticGetSomeValuePlusArg(0), 42u);
@@ -74,9 +74,9 @@
   // Same, but with the default implementation overridden.
   FloatType flt = Float32Type::get(&context);
-  ASSERT_FALSE(flt.isa<TestExternalTypeInterface>());
+  ASSERT_FALSE(isa<TestExternalTypeInterface>(flt));
   Float32Type::attachInterface<OverridingModel>(context);
-  iface = flt.dyn_cast<TestExternalTypeInterface>();
+  iface = dyn_cast<TestExternalTypeInterface>(flt);
   ASSERT_TRUE(iface != nullptr);
   EXPECT_EQ(iface.getBitwidthPlusArg(10), 42u);
   EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 52u);
@@ -86,7 +86,7 @@
   // Other contexts shouldn't have the attribute attached.
   MLIRContext other;
   IntegerType i8other = IntegerType::get(&other, 8);
-  EXPECT_FALSE(i8other.isa<TestExternalTypeInterface>());
+  EXPECT_FALSE(isa<TestExternalTypeInterface>(i8other));
 }

 /// External interface model for the test type from the test dialect.
@@ -111,7 +111,7 @@
   MLIRContext context(registry);
   context.loadDialect<test::TestDialect>();
   test::TestType testType = test::TestType::get(&context);
-  auto iface = testType.dyn_cast<TestExternalTypeInterface>();
+  auto iface = dyn_cast<TestExternalTypeInterface>(testType);
   ASSERT_TRUE(iface != nullptr);
   EXPECT_EQ(iface.getBitwidthPlusArg(42), 42u);
   EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 20u);
@@ -130,9 +130,9 @@
   MLIRContext context;
   context.loadDialect<test::TestDialect>();
   test::TestType testType = test::TestType::get(&context);
-  EXPECT_FALSE(testType.isa<TestExternalTypeInterface>());
+  EXPECT_FALSE(isa<TestExternalTypeInterface>(testType));
   context.appendDialectRegistry(registry);
-  EXPECT_TRUE(testType.isa<TestExternalTypeInterface>());
+  EXPECT_TRUE(isa<TestExternalTypeInterface>(testType));
 }

 TEST(InterfaceAttachment, RepeatedRegistration) {
@@ -156,13 +156,13 @@
   MLIRContext context(registry);
   IntegerType i16 = IntegerType::get(&context, 16);
-  EXPECT_TRUE(i16.isa<TestExternalTypeInterface>());
+  EXPECT_TRUE(isa<TestExternalTypeInterface>(i16));

   MLIRContext initiallyEmpty;
   IntegerType i32 = IntegerType::get(&initiallyEmpty, 32);
-  EXPECT_FALSE(i32.isa<TestExternalTypeInterface>());
+  EXPECT_FALSE(isa<TestExternalTypeInterface>(i32));
   initiallyEmpty.appendDialectRegistry(registry);
-  EXPECT_TRUE(i32.isa<TestExternalTypeInterface>());
+  EXPECT_TRUE(isa<TestExternalTypeInterface>(i32));
 }

 /// The interface provides a default implementation that expects
@@ -181,9 +181,8 @@
     : public TestExternalFallbackTypeInterface::FallbackModel<
           TestExternalFallbackTypeVectorModel> {
   unsigned getBitwidth(Type type) const {
-    IntegerType elementType = type.cast<VectorType>()
-                                  .getElementType()
-                                  .dyn_cast_or_null<IntegerType>();
+    IntegerType elementType =
+        dyn_cast_or_null<IntegerType>(type.cast<VectorType>().getElementType());
     return elementType ? elementType.getWidth() : 0;
   }
 };
@@ -193,16 +192,16 @@
   // Just check that we can attach the interface.
   IntegerType i8 = IntegerType::get(&context, 8);
-  ASSERT_FALSE(i8.isa<TestExternalFallbackTypeInterface>());
+  ASSERT_FALSE(isa<TestExternalFallbackTypeInterface>(i8));
   IntegerType::attachInterface<TestExternalFallbackTypeIntegerModel>(context);
-  ASSERT_TRUE(i8.isa<TestExternalFallbackTypeInterface>());
+  ASSERT_TRUE(isa<TestExternalFallbackTypeInterface>(i8));

   // Call the method so it is guaranteed not to be instantiated.
   VectorType vec = VectorType::get({42}, i8);
-  ASSERT_FALSE(vec.isa<TestExternalFallbackTypeInterface>());
+  ASSERT_FALSE(isa<TestExternalFallbackTypeInterface>(vec));
   VectorType::attachInterface<TestExternalFallbackTypeVectorModel>(context);
-  ASSERT_TRUE(vec.isa<TestExternalFallbackTypeInterface>());
-  EXPECT_EQ(vec.cast<TestExternalFallbackTypeInterface>().getBitwidth(), 8u);
+  ASSERT_TRUE(isa<TestExternalFallbackTypeInterface>(vec));
+  EXPECT_EQ(cast<TestExternalFallbackTypeInterface>(vec).getBitwidth(), 8u);
 }

 /// External model for attribute interfaces.
@@ -210,7 +209,7 @@
     : public TestExternalAttrInterface::ExternalModel<
           TestExternalIntegerAttrModel, IntegerAttr> {
   const Dialect *getDialectPtr(Attribute attr) const {
-    return &attr.cast<IntegerAttr>().getDialect();
+    return &cast<IntegerAttr>(attr).getDialect();
   }

   static int getSomeNumber() { return 42; }
@@ -222,9 +221,9 @@
   // Attribute interfaces use the exact same mechanism as types, so just check
   // that the basics work for attributes.
   IntegerAttr attr = IntegerAttr::get(IntegerType::get(&context, 32), 42);
-  ASSERT_FALSE(attr.isa<TestExternalAttrInterface>());
+  ASSERT_FALSE(isa<TestExternalAttrInterface>(attr));
   IntegerAttr::attachInterface<TestExternalIntegerAttrModel>(context);
-  auto iface = attr.dyn_cast<TestExternalAttrInterface>();
+  auto iface = dyn_cast<TestExternalAttrInterface>(attr);
   ASSERT_TRUE(iface != nullptr);
   EXPECT_EQ(iface.getDialectPtr(), &attr.getDialect());
   EXPECT_EQ(iface.getSomeNumber(), 42);
@@ -253,14 +252,14 @@
   MLIRContext context(registry);
   context.loadDialect<test::TestDialect>();
   auto attr = test::SimpleAAttr::get(&context);
-  EXPECT_TRUE(attr.isa<TestExternalAttrInterface>());
+  EXPECT_TRUE(isa<TestExternalAttrInterface>(attr));

   MLIRContext initiallyEmpty;
   initiallyEmpty.loadDialect<test::TestDialect>();
   attr = test::SimpleAAttr::get(&initiallyEmpty);
-  EXPECT_FALSE(attr.isa<TestExternalAttrInterface>());
+  EXPECT_FALSE(isa<TestExternalAttrInterface>(attr));
   initiallyEmpty.appendDialectRegistry(registry);
-  EXPECT_TRUE(attr.isa<TestExternalAttrInterface>());
+  EXPECT_TRUE(isa<TestExternalAttrInterface>(attr));
 }

 /// External interface model for the module operation. Only provides non-default
diff --git a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
--- a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
+++ b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
@@ -152,16 +152,16 @@
   static unsigned getTypeSizeInBits(Type type, const DataLayout &dataLayout,
                                     DataLayoutEntryListRef params) {
     // Make a recursive query.
-    if (type.isa<FloatType>())
+    if (isa<FloatType>(type))
       return dataLayout.getTypeSizeInBits(
           IntegerType::get(type.getContext(), type.getIntOrFloatBitWidth()));

     // Handle built-in types that are not handled by the default process.
-    if (auto iType = type.dyn_cast<IntegerType>()) {
+    if (auto iType = dyn_cast<IntegerType>(type)) {
       for (DataLayoutEntryInterface entry : params)
         if (entry.getKey().dyn_cast<Type>() == type)
           return 8 *
-                 entry.getValue().cast<IntegerAttr>().getValue().getZExtValue();
+                 cast<IntegerAttr>(entry.getValue()).getValue().getZExtValue();
       return 8 * iType.getIntOrFloatBitWidth();
     }
@@ -217,7 +217,7 @@
   void printAttribute(Attribute attr,
                       DialectAsmPrinter &printer) const override {
     printer << "spec<";
-    llvm::interleaveComma(attr.cast<CustomDataLayoutSpec>().getEntries(),
+    llvm::interleaveComma(cast<CustomDataLayoutSpec>(attr).getEntries(),
                           printer);
     printer << ">";
   }
@@ -244,7 +244,7 @@
   }

   void printType(Type type, DialectAsmPrinter &printer) const override {
-    if (type.isa<SingleQueryType>())
+    if (isa<SingleQueryType>(type))
       printer << "single_query";
     else
       printer << "no_layout";
diff --git a/mlir/unittests/Pass/PassManagerTest.cpp b/mlir/unittests/Pass/PassManagerTest.cpp
--- a/mlir/unittests/Pass/PassManagerTest.cpp
+++ b/mlir/unittests/Pass/PassManagerTest.cpp
@@ -75,12 +75,12 @@
   // Verify that each function got annotated with expected attributes.
   for (func::FuncOp func : module->getOps<func::FuncOp>()) {
-    ASSERT_TRUE(func->getAttr("isFunc").isa<BoolAttr>());
-    EXPECT_TRUE(func->getAttr("isFunc").cast<BoolAttr>().getValue());
+    ASSERT_TRUE(isa<BoolAttr>(func->getAttr("isFunc")));
+    EXPECT_TRUE(cast<BoolAttr>(func->getAttr("isFunc")).getValue());

     bool isSecret = func.getName() == "secret";
-    ASSERT_TRUE(func->getAttr("isSecret").isa<BoolAttr>());
-    EXPECT_EQ(func->getAttr("isSecret").cast<BoolAttr>().getValue(), isSecret);
+    ASSERT_TRUE(isa<BoolAttr>(func->getAttr("isSecret")));
+    EXPECT_EQ(cast<BoolAttr>(func->getAttr("isSecret")).getValue(), isSecret);
   }
 }
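Note (not part of the patch itself): every hunk above applies the same mechanical rewrite, replacing the deprecated member-function casts on Value, Type, Attribute, and Location with the equivalent free-function forms. A minimal standalone sketch of the before/after styles follows; resultOrArgIndex is a hypothetical helper invented for illustration and does not appear in the patch.

// Illustrative sketch only: contrasts the deprecated member-function casts
// with the free-function casts this patch migrates to. `resultOrArgIndex`
// is a hypothetical helper, not code from the patch.
#include "mlir/IR/Value.h"

using namespace mlir;

static unsigned resultOrArgIndex(Value value) {
  // Old style (the `-` lines in the hunks above):
  //   if (BlockArgument arg = value.dyn_cast<BlockArgument>())
  //     return arg.getArgNumber();
  //   return value.cast<OpResult>().getResultNumber();

  // New style (the `+` lines): the free functions dispatch through the same
  // llvm::CastInfo machinery, so behavior is unchanged.
  if (BlockArgument arg = dyn_cast<BlockArgument>(value))
    return arg.getArgNumber();
  return cast<OpResult>(value).getResultNumber();
}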