diff --git a/mlir/docs/Tutorials/Toy/Ch-7.md b/mlir/docs/Tutorials/Toy/Ch-7.md --- a/mlir/docs/Tutorials/Toy/Ch-7.md +++ b/mlir/docs/Tutorials/Toy/Ch-7.md @@ -287,8 +287,7 @@ return nullptr; // Check that the type is either a TensorType or another StructType. - if (!elementType.isa() && - !elementType.isa()) { + if (!elementType.isa()) { parser.emitError(typeLoc, "element type for a struct must either " "be a TensorType or a StructType, got: ") << elementType; diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp @@ -510,8 +510,7 @@ return nullptr; // Check that the type is either a TensorType or another StructType. - if (!elementType.isa() && - !elementType.isa()) { + if (!elementType.isa()) { parser.emitError(typeLoc, "element type for a struct must either " "be a TensorType or a StructType, got: ") << elementType; diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h --- a/mlir/include/mlir/EDSC/Builders.h +++ b/mlir/include/mlir/EDSC/Builders.h @@ -139,15 +139,12 @@ StructuredIndexed(Value v, ArrayRef indexings) : value(v), exprs(indexings.begin(), indexings.end()) { - assert((v.getType().isa() || - v.getType().isa() || - v.getType().isa()) && + assert((v.getType().isa()) && "MemRef, RankedTensor or Vector expected"); } StructuredIndexed(Type t, ArrayRef indexings) : type(t), exprs(indexings.begin(), indexings.end()) { - assert((t.isa() || t.isa() || - t.isa()) && + assert((t.isa()) && "MemRef, RankedTensor or Vector expected"); } diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -85,6 +85,8 @@ bool operator!() const { return impl == nullptr; } template bool isa() const; + template + bool isa() const; template U dyn_cast() const; template U dyn_cast_or_null() const; template U cast() const; @@ -1630,6 +1632,12 @@ assert(impl && "isa<> used on a null attribute."); return U::classof(*this); } + +template +bool Attribute::isa() const { + return isa() || isa(); +} + template U Attribute::dyn_cast() const { return isa() ? U(impl) : U(nullptr); } diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h --- a/mlir/include/mlir/IR/Matchers.h +++ b/mlir/include/mlir/IR/Matchers.h @@ -97,9 +97,9 @@ return false; auto type = op->getResult(0).getType(); - if (type.isa() || type.isa()) + if (type.isa()) return attr_value_binder(bind_value).match(attr); - if (type.isa() || type.isa()) { + if (type.isa()) { if (auto splatAttr = attr.dyn_cast()) { return attr_value_binder(bind_value) .match(splatAttr.getSplatValue()); diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h --- a/mlir/include/mlir/IR/StandardTypes.h +++ b/mlir/include/mlir/IR/StandardTypes.h @@ -357,7 +357,7 @@ /// Returns true of the given type can be used as an element of a vector type. /// In particular, vectors can consist of integer or float primitives. static bool isValidElementType(Type t) { - return t.isa() || t.isa(); + return t.isa(); } ArrayRef getShape() const; @@ -381,9 +381,8 @@ // Note: Non standard/builtin types are allowed to exist within tensor // types. Dialects are expected to verify that tensor types have a valid // element type within that dialect. - return type.isa() || type.isa() || - type.isa() || type.isa() || - type.isa() || type.isa() || + return type.isa() || (type.getKind() > Type::Kind::LAST_STANDARD_TYPE); } diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -121,6 +121,8 @@ bool operator!() const { return impl == nullptr; } template bool isa() const; + template + bool isa() const; template U dyn_cast() const; template U dyn_cast_or_null() const; template U cast() const; @@ -271,6 +273,12 @@ assert(impl && "isa<> used on a null type."); return U::classof(*this); } + +template +bool Type::isa() const { + return isa() || isa(); +} + template U Type::dyn_cast() const { return isa() ? U(impl) : U(nullptr); } diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -81,6 +81,12 @@ assert(*this && "isa<> used on a null type."); return U::classof(*this); } + + template + bool isa() const { + return isa() || isa(); + } + template U dyn_cast() const { return isa() ? U(ownerAndKind) : U(nullptr); } diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -956,8 +956,7 @@ // Walk this 'affine.for' operation to gather all memory regions. auto result = block.walk(start, end, [&](Operation *opInst) -> WalkResult { - if (!isa(opInst) && - !isa(opInst)) { + if (!isa(opInst)) { // Neither load nor a store op. return WalkResult::advance(); } @@ -1017,11 +1016,9 @@ // Collect all load and store ops in loop nest rooted at 'forOp'. SmallVector loadAndStoreOpInsts; auto walkResult = forOp.walk([&](Operation *opInst) -> WalkResult { - if (isa(opInst) || - isa(opInst)) + if (isa(opInst)) loadAndStoreOpInsts.push_back(opInst); - else if (!isa(opInst) && !isa(opInst) && - !isa(opInst) && + else if (!isa(opInst) && !MemoryEffectOpInterface::hasNoEffect(opInst)) return WalkResult::interrupt(); diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -303,7 +303,7 @@ auto converted = convertType(t).dyn_cast_or_null(); if (!converted) return {}; - if (t.isa() || t.isa()) + if (t.isa()) converted = converted.getPointerTo(); inputs.push_back(converted); } @@ -1044,7 +1044,7 @@ FunctionType type, SmallVectorImpl &argsInfo) const { argsInfo.reserve(type.getNumInputs()); for (auto en : llvm::enumerate(type.getInputs())) { - if (en.value().isa() || en.value().isa()) + if (en.value().isa()) argsInfo.push_back({en.index(), en.value()}); } } diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -518,7 +518,7 @@ return failure(); // std.constant should only have vector or tenor types. - assert(srcType.isa() || srcType.isa()); + assert((srcType.isa())); auto dstType = typeConverter.convertType(srcType); if (!dstType) diff --git a/mlir/lib/Dialect/Affine/EDSC/Builders.cpp b/mlir/lib/Dialect/Affine/EDSC/Builders.cpp --- a/mlir/lib/Dialect/Affine/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Affine/EDSC/Builders.cpp @@ -117,7 +117,7 @@ return ValueBuilder(lhs, rhs); } else if (thisType.isa()) { return ValueBuilder(lhs, rhs); - } else if (thisType.isa() || thisType.isa()) { + } else if (thisType.isa()) { auto aggregateType = thisType.cast(); if (aggregateType.getElementType().isSignlessInteger()) return ValueBuilder(lhs, rhs); diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp --- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp @@ -218,7 +218,7 @@ nest->walk([&](Operation *op) { if (auto forOp = dyn_cast(op)) promoteIfSingleIteration(forOp); - else if (isa(op) || isa(op)) + else if (isa(op)) copyOps.push_back(op); }); diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp --- a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp @@ -80,7 +80,7 @@ // If the body of a predicated region has a for loop, we don't hoist the // 'affine.if'. return false; - } else if (isa(op) || isa(op)) { + } else if (isa(op)) { // TODO(asabne): Support DMA ops. return false; } else if (!isa(op)) { @@ -91,7 +91,7 @@ for (auto *user : memref.getUsers()) { // If this memref has a user that is a DMA, give up because these // operations write to this memref. - if (isa(op) || isa(op)) { + if (isa(op)) { return false; } // If the memref used by the load/store is used in a store elsewhere in diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -923,11 +923,11 @@ return nullptr; // Fuse when consumer is GenericOp or IndexedGenericOp. - if (isa(consumer) || isa(consumer)) { + if (isa(consumer)) { auto linalgOpConsumer = cast(consumer); if (!linalgOpConsumer.hasTensorSemantics()) return nullptr; - if (isa(producer) || isa(producer)) { + if (isa(producer)) { auto linalgOpProducer = cast(producer); if (linalgOpProducer.hasTensorSemantics()) return FuseGenericOpsOnTensors::fuse(linalgOpProducer, linalgOpConsumer, diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp --- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp @@ -46,7 +46,7 @@ static bool isValidQuantizationSpec(Attribute quantSpec, Type expressed) { if (auto typeAttr = quantSpec.dyn_cast()) { Type spec = typeAttr.getValue(); - if (spec.isa() || spec.isa()) + if (spec.isa()) return false; // The spec should be either a quantized type which is compatible to the diff --git a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp --- a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp @@ -69,8 +69,7 @@ } // Is the constant value a type expressed in a way that we support? - if (!value.isa() && !value.isa() && - !value.isa()) { + if (!value.isa()) { return failure(); } diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -1292,7 +1292,7 @@ return failure(); Type type = value.getType(); - if (type.isa() || type.isa()) { + if (type.isa()) { if (parser.parseColonType(type)) return failure(); } @@ -1827,8 +1827,8 @@ // TODO: Currently only variable initialization with specialization // constants and other variables is supported. They could be normal // constants in the module scope as well. - if (!initOp || !(isa(initOp) || - isa(initOp))) { + if (!initOp || + !isa(initOp)) { return varOp.emitOpError("initializer must be result of a " "spv.specConstant or spv.globalVariable op"); } @@ -2093,8 +2093,7 @@ static LogicalResult verify(spirv::MergeOp mergeOp) { auto *parentOp = mergeOp.getParentOp(); - if (!parentOp || - (!isa(parentOp) && !isa(parentOp))) + if (!parentOp || !isa(parentOp)) return mergeOp.emitOpError( "expected parent op to be 'spv.selection' or 'spv.loop'"); @@ -2620,9 +2619,9 @@ // SPIR-V spec: "Initializer must be an from a constant instruction or // a global (module scope) OpVariable instruction". auto *initOp = varOp.getOperand(0).getDefiningOp(); - if (!initOp || !(isa(initOp) || // for normal constant - isa(initOp) || // for spec constant - isa(initOp))) + if (!initOp || !isa(initOp)) return varOp.emitOpError("initializer must be the result of a " "constant or spv.globalVariable op"); } diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -1175,8 +1175,7 @@ if (value.getType() != type) return false; // Finally, check that the attribute kind is handled. - return value.isa() || value.isa() || - value.isa() || value.isa(); + return value.isa(); } void ConstantFloatOp::build(OpBuilder &builder, OperationState &result, @@ -2102,7 +2101,7 @@ // If the result type is a vector or tensor, the type can be a mask with the // same elements. Type resultType = op.getType(); - if (!resultType.isa() && !resultType.isa()) + if (!resultType.isa()) return op.emitOpError() << "expected condition to be a signless i1, but got " << conditionType; @@ -2221,8 +2220,7 @@ assert(operands.size() == 1 && "splat takes one operand"); auto constOperand = operands.front(); - if (!constOperand || - (!constOperand.isa() && !constOperand.isa())) + if (!constOperand || !constOperand.isa()) return {}; auto shapedType = getType().cast(); diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp --- a/mlir/lib/Dialect/Traits.cpp +++ b/mlir/lib/Dialect/Traits.cpp @@ -107,7 +107,7 @@ // Returns the type kind if the given type is a vector or ranked tensor type. // Returns llvm::None otherwise. auto getCompositeTypeKind = [](Type type) -> Optional { - if (type.isa() || type.isa()) + if (type.isa()) return static_cast(type.getKind()); return llvm::None; }; diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -337,7 +337,7 @@ } static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) { - if (type.isa() || type.isa()) + if (type.isa()) return success(); return emitError(loc, "expected integer or index type"); } @@ -1090,7 +1090,7 @@ DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, ArrayRef data, bool isSplat) { - assert((type.isa() || type.isa()) && + assert((type.isa()) && "type must be ranked tensor or vector"); assert(type.hasStaticShape() && "type must have static shape"); return Base::get(type.getContext(), StandardAttributes::DenseIntOrFPElements, @@ -1247,7 +1247,7 @@ DenseElementsAttr values) { assert(indices.getType().getElementType().isInteger(64) && "expected sparse indices to be 64-bit integer values"); - assert((type.isa() || type.isa()) && + assert((type.isa()) && "type must be ranked tensor or vector"); assert(type.hasStaticShape() && "type must have static shape"); return Base::get(type.getContext(), StandardAttributes::SparseElements, type, diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -72,11 +72,11 @@ } bool Type::isSignlessIntOrIndex() { - return isa() || isSignlessInteger(); + return isSignlessInteger() || isa(); } bool Type::isSignlessIntOrIndexOrFloat() { - return isa() || isSignlessInteger() || isa(); + return isSignlessInteger() || isa(); } bool Type::isSignlessIntOrFloat() { @@ -85,7 +85,7 @@ bool Type::isIntOrIndex() { return isa() || isIndex(); } -bool Type::isIntOrFloat() { return isa() || isa(); } +bool Type::isIntOrFloat() { return isa(); } bool Type::isIntOrIndexOrFloat() { return isIntOrFloat() || isIndex(); } @@ -200,7 +200,7 @@ int64_t ShapedType::getRank() const { return getShape().size(); } bool ShapedType::hasRank() const { - return !isa() && !isa(); + return !isa(); } int64_t ShapedType::getDimSize(unsigned idx) const { @@ -233,7 +233,7 @@ // Tensors can have vectors and other tensors as elements, other shaped types // cannot. assert(isa() && "unsupported element type"); - assert((elementType.isa() || elementType.isa()) && + assert((elementType.isa()) && "unsupported tensor element type"); return getNumElements() * elementType.cast().getSizeInBits(); } @@ -398,8 +398,8 @@ auto *context = elementType.getContext(); // Check that memref is formed from allowed types. - if (!elementType.isIntOrFloat() && !elementType.isa() && - !elementType.isa()) + if (!elementType.isIntOrFloat() && + !elementType.isa()) return emitOptionalError(location, "invalid memref element type"), MemRefType(); @@ -476,8 +476,8 @@ UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType, unsigned memorySpace) { // Check that memref is formed from allowed types. - if (!elementType.isIntOrFloat() && !elementType.isa() && - !elementType.isa()) + if (!elementType.isIntOrFloat() && + !elementType.isa()) return emitError(loc, "invalid memref element type"); return success(); } diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -397,7 +397,7 @@ for (Attribute attr : llvm::drop_begin(attrRange, index)) { /// Check for a nested container attribute, these will also need to be /// walked. - if (attr.isa() || attr.isa()) { + if (attr.isa()) { attrWorklist.push_back(attr); curAccessChain.push_back(-1); return WalkResult::advance(); diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp --- a/mlir/lib/Parser/AttributeParser.cpp +++ b/mlir/lib/Parser/AttributeParser.cpp @@ -345,7 +345,7 @@ return apVal ? FloatAttr::get(floatType, *apVal) : Attribute(); } - if (!type.isa() && !type.isa()) + if (!type.isa()) return emitError(loc, "integer literal not valid for specified type"), nullptr; @@ -823,7 +823,7 @@ return nullptr; } - if (!type.isa() && !type.isa()) { + if (!type.isa()) { emitError("elements literal must be a ranked tensor or vector type"); return nullptr; } diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp --- a/mlir/lib/Parser/TypeParser.cpp +++ b/mlir/lib/Parser/TypeParser.cpp @@ -217,8 +217,8 @@ return nullptr; // Check that memref is formed from allowed types. - if (!elementType.isIntOrFloat() && !elementType.isa() && - !elementType.isa()) + if (!elementType.isIntOrFloat() && + !elementType.isa()) return emitError(typeLoc, "invalid memref element type"), nullptr; // Parse semi-affine-map-composition. diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -773,8 +773,7 @@ LogicalResult ModuleTranslation::checkSupportedModuleOps(Operation *m) { for (Operation &o : getModuleBody(m).getOperations()) - if (!isa(&o) && !isa(&o) && - !o.isKnownTerminator()) + if (!isa(&o) && !o.isKnownTerminator()) return o.emitOpError("unsupported module-level operation"); return success(); } diff --git a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp --- a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp @@ -294,7 +294,7 @@ unsigned count = 0; stats->opCountMap[childForOp] = 0; for (auto &op : *forOp.getBody()) { - if (!isa(op) && !isa(op)) + if (!isa(op)) ++count; } stats->opCountMap[childForOp] = count; diff --git a/mlir/test/lib/Transforms/TestMemRefDependenceCheck.cpp b/mlir/test/lib/Transforms/TestMemRefDependenceCheck.cpp --- a/mlir/test/lib/Transforms/TestMemRefDependenceCheck.cpp +++ b/mlir/test/lib/Transforms/TestMemRefDependenceCheck.cpp @@ -103,7 +103,7 @@ // Collect the loads and stores within the function. loadsAndStores.clear(); getFunction().walk([&](Operation *op) { - if (isa(op) || isa(op)) + if (isa(op)) loadsAndStores.push_back(op); });