diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -219,76 +219,78 @@
   /// conversion exists, success otherwise. If the new set of types is empty,
   /// the type is removed and any usages of the existing value are expected to
   /// be removed during conversion.
-  LogicalResult convertType(Type t, SmallVectorImpl<Type> &results);
+  LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) const;
 
   /// This hook simplifies defining 1-1 type conversions. This function returns
   /// the type to convert to on success, and a null type on failure.
-  Type convertType(Type t);
+  Type convertType(Type t) const;
 
   /// Attempts a 1-1 type conversion, expecting the result type to be
   /// `TargetType`. Returns the converted type cast to `TargetType` on success,
   /// and a null type on conversion or cast failure.
-  template <typename TargetType>
-  TargetType convertType(Type t) {
+  template <typename TargetType> TargetType convertType(Type t) const {
     return dyn_cast_or_null<TargetType>(convertType(t));
   }
 
   /// Convert the given set of types, filling 'results' as necessary. This
   /// returns failure if the conversion of any of the types fails, success
   /// otherwise.
-  LogicalResult convertTypes(TypeRange types, SmallVectorImpl<Type> &results);
+  LogicalResult convertTypes(TypeRange types,
+                             SmallVectorImpl<Type> &results) const;
 
   /// Return true if the given type is legal for this type converter, i.e. the
   /// type converts to itself.
-  bool isLegal(Type type);
+  bool isLegal(Type type) const;
+
   /// Return true if all of the given types are legal for this type converter.
   template <typename RangeT>
   std::enable_if_t<!std::is_convertible<RangeT, Type>::value &&
                        !std::is_convertible<RangeT, Operation *>::value,
                    bool>
-  isLegal(RangeT &&range) {
+  isLegal(RangeT &&range) const {
     return llvm::all_of(range, [this](Type type) { return isLegal(type); });
   }
   /// Return true if the given operation has legal operand and result types.
-  bool isLegal(Operation *op);
+  bool isLegal(Operation *op) const;
 
   /// Return true if the types of block arguments within the region are legal.
-  bool isLegal(Region *region);
+  bool isLegal(Region *region) const;
 
   /// Return true if the inputs and outputs of the given function type are
   /// legal.
-  bool isSignatureLegal(FunctionType ty);
+  bool isSignatureLegal(FunctionType ty) const;
 
   /// This method allows for converting a specific argument of a signature. It
   /// takes as inputs the original argument input number, type.
   /// On success, it populates 'result' with any new mappings.
   LogicalResult convertSignatureArg(unsigned inputNo, Type type,
-                                    SignatureConversion &result);
+                                    SignatureConversion &result) const;
   LogicalResult convertSignatureArgs(TypeRange types,
                                      SignatureConversion &result,
-                                     unsigned origInputOffset = 0);
+                                     unsigned origInputOffset = 0) const;
 
   /// This function converts the type signature of the given block, by invoking
   /// 'convertSignatureArg' for each argument. This function should return a
   /// valid conversion for the signature on success, std::nullopt otherwise.
-  std::optional<SignatureConversion> convertBlockSignature(Block *block);
+  std::optional<SignatureConversion> convertBlockSignature(Block *block) const;
 
   /// Materialize a conversion from a set of types into one result type by
   /// generating a cast sequence of some kind. See the respective
   /// `add*Materialization` for more information on the context for these
   /// methods.
   Value materializeArgumentConversion(OpBuilder &builder, Location loc,
-                                      Type resultType, ValueRange inputs) {
+                                      Type resultType,
+                                      ValueRange inputs) const {
     return materializeConversion(argumentMaterializations, builder, loc,
                                  resultType, inputs);
   }
   Value materializeSourceConversion(OpBuilder &builder, Location loc,
-                                    Type resultType, ValueRange inputs) {
+                                    Type resultType, ValueRange inputs) const {
     return materializeConversion(sourceMaterializations, builder, loc,
                                  resultType, inputs);
   }
   Value materializeTargetConversion(OpBuilder &builder, Location loc,
-                                    Type resultType, ValueRange inputs) {
+                                    Type resultType, ValueRange inputs) const {
     return materializeConversion(targetMaterializations, builder, loc,
                                  resultType, inputs);
   }
@@ -297,7 +299,8 @@
   /// the registered conversion functions. If no applicable conversion has been
   /// registered, return std::nullopt. Note that the empty attribute/`nullptr`
   /// is a valid return value for this function.
-  std::optional<Attribute> convertTypeAttribute(Type type, Attribute attr);
+  std::optional<Attribute> convertTypeAttribute(Type type,
+                                                Attribute attr) const;
 
 private:
   /// The signature of the callback used to convert a type. If the new set of
@@ -316,16 +319,17 @@
 
   /// Attempt to materialize a conversion using one of the provided
   /// materialization functions.
-  Value materializeConversion(
-      MutableArrayRef<MaterializationCallbackFn> materializations,
-      OpBuilder &builder, Location loc, Type resultType, ValueRange inputs);
+  Value
+  materializeConversion(ArrayRef<MaterializationCallbackFn> materializations,
+                        OpBuilder &builder, Location loc, Type resultType,
+                        ValueRange inputs) const;
 
   /// Generate a wrapper for the given callback. This allows for accepting
   /// different callback forms, that all compose into a single version.
   /// With callback of form: `std::optional<Type>(T)`
   template <typename T, typename FnT>
   std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
-  wrapCallback(FnT &&callback) {
+  wrapCallback(FnT &&callback) const {
     return wrapCallback<T>(
         [callback = std::forward<FnT>(callback)](
             T type, SmallVectorImpl<Type> &results, ArrayRef<Type>) {
@@ -343,7 +347,7 @@
   template <typename T, typename FnT>
   std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>,
                    ConversionCallbackFn>
-  wrapCallback(FnT &&callback) {
+  wrapCallback(FnT &&callback) const {
     return wrapCallback<T>(
         [callback = std::forward<FnT>(callback)](
             T type, SmallVectorImpl<Type> &results, ArrayRef<Type>) {
@@ -356,7 +360,7 @@
   std::enable_if_t<
       std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &, ArrayRef<Type>>,
       ConversionCallbackFn>
-  wrapCallback(FnT &&callback) {
+  wrapCallback(FnT &&callback) const {
     return [callback = std::forward<FnT>(callback)](
                Type type, SmallVectorImpl<Type> &results,
                ArrayRef<Type> callStack) -> std::optional<LogicalResult> {
@@ -378,7 +382,7 @@
   /// may take any subclass of `Type` and the wrapper will check for the target
   /// type to be of the expected class before calling the callback.
   template <typename T, typename FnT>
-  MaterializationCallbackFn wrapMaterialization(FnT &&callback) {
+  MaterializationCallbackFn wrapMaterialization(FnT &&callback) const {
     return [callback = std::forward<FnT>(callback)](
                OpBuilder &builder, Type resultType, ValueRange inputs,
                Location loc) -> std::optional<Value> {
@@ -394,7 +398,7 @@
   /// callback.
   template <typename T, typename A, typename FnT>
   TypeAttributeConversionCallbackFn
-  wrapTypeAttributeConversion(FnT &&callback) {
+  wrapTypeAttributeConversion(FnT &&callback) const {
     return [callback = std::forward<FnT>(callback)](
                Type type, Attribute attr) -> AttributeConversionResult {
       if (T derivedType = dyn_cast<T>(type)) {
@@ -428,13 +432,18 @@
   /// A set of cached conversions to avoid recomputing in the common case.
   /// Direct 1-1 conversions are the most common, so this cache stores the
   /// successful 1-1 conversions as well as all failed conversions.
-  DenseMap<Type, Type> cachedDirectConversions;
+  mutable DenseMap<Type, Type> cachedDirectConversions;
   /// This cache stores the successful 1->N conversions, where N != 1.
-  DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
+  mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
 
   /// Stores the types that are being converted in the case when convertType
   /// is being called recursively to convert nested types.
+  ///
+  /// The two caches above and this call stack are `mutable` so that the const
+  /// conversion methods (e.g. `convertType`) can update them. Those updates
+  /// are not synchronized, so a single TypeConverter must not be used from
+  /// multiple threads concurrently without external locking.
-  SmallVector<Type, 2> conversionCallStack;
+  mutable SmallVector<Type, 2> conversionCallStack;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2906,7 +2906,7 @@
 }
 
 LogicalResult TypeConverter::convertType(Type t,
-                                         SmallVectorImpl<Type> &results) {
+                                         SmallVectorImpl<Type> &results) const {
   auto existingIt = cachedDirectConversions.find(t);
   if (existingIt != cachedDirectConversions.end()) {
     if (existingIt->second)
@@ -2925,7 +2925,7 @@
   conversionCallStack.push_back(t);
   auto popConversionCallStack =
       llvm::make_scope_exit([this]() { conversionCallStack.pop_back(); });
-  for (ConversionCallbackFn &converter : llvm::reverse(conversions)) {
+  for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
     if (std::optional<LogicalResult> result =
             converter(t, results, conversionCallStack)) {
       if (!succeeded(*result)) {
@@ -2943,7 +2943,7 @@
   return failure();
 }
 
-Type TypeConverter::convertType(Type t) {
+Type TypeConverter::convertType(Type t) const {
   // Use the multi-type result version to convert the type.
   SmallVector<Type, 1> results;
   if (failed(convertType(t, results)))
@@ -2953,31 +2953,35 @@
   return results.size() == 1 ? results.front() : nullptr;
 }
 
-LogicalResult TypeConverter::convertTypes(TypeRange types,
-                                          SmallVectorImpl<Type> &results) {
+LogicalResult
+TypeConverter::convertTypes(TypeRange types,
+                            SmallVectorImpl<Type> &results) const {
   for (Type type : types)
     if (failed(convertType(type, results)))
       return failure();
   return success();
 }
 
-bool TypeConverter::isLegal(Type type) { return convertType(type) == type; }
-bool TypeConverter::isLegal(Operation *op) {
+bool TypeConverter::isLegal(Type type) const {
+  return convertType(type) == type;
+}
+bool TypeConverter::isLegal(Operation *op) const {
   return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
 }
 
-bool TypeConverter::isLegal(Region *region) {
+bool TypeConverter::isLegal(Region *region) const {
   return llvm::all_of(*region, [this](Block &block) {
     return isLegal(block.getArgumentTypes());
   });
 }
 
-bool TypeConverter::isSignatureLegal(FunctionType ty) {
+bool TypeConverter::isSignatureLegal(FunctionType ty) const {
   return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
 }
 
-LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
-                                                 SignatureConversion &result) {
+LogicalResult
+TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
+                                   SignatureConversion &result) const {
   // Try to convert the given input type.
   SmallVector<Type, 1> convertedTypes;
   if (failed(convertType(type, convertedTypes)))
@@ -2991,9 +2995,10 @@
   result.addInputs(inputNo, convertedTypes);
   return success();
 }
-LogicalResult TypeConverter::convertSignatureArgs(TypeRange types,
-                                                  SignatureConversion &result,
-                                                  unsigned origInputOffset) {
+LogicalResult
+TypeConverter::convertSignatureArgs(TypeRange types,
+                                    SignatureConversion &result,
+                                    unsigned origInputOffset) const {
   for (unsigned i = 0, e = types.size(); i != e; ++i)
     if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
       return failure();
@@ -3001,16 +3006,16 @@
 }
 
 Value TypeConverter::materializeConversion(
-    MutableArrayRef<MaterializationCallbackFn> materializations,
-    OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) {
-  for (MaterializationCallbackFn &fn : llvm::reverse(materializations))
+    ArrayRef<MaterializationCallbackFn> materializations, OpBuilder &builder,
+    Location loc, Type resultType, ValueRange inputs) const {
+  for (const MaterializationCallbackFn &fn : llvm::reverse(materializations))
     if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
       return *result;
   return nullptr;
 }
 
-auto TypeConverter::convertBlockSignature(Block *block)
-    -> std::optional<SignatureConversion> {
+std::optional<TypeConverter::SignatureConversion>
+TypeConverter::convertBlockSignature(Block *block) const {
   SignatureConversion conversion(block->getNumArguments());
   if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
     return std::nullopt;
@@ -3052,9 +3057,9 @@
   return impl.getPointer();
 }
 
-std::optional<Attribute> TypeConverter::convertTypeAttribute(Type type,
-                                                             Attribute attr) {
-  for (TypeAttributeConversionCallbackFn &fn :
+std::optional<Attribute>
+TypeConverter::convertTypeAttribute(Type type, Attribute attr) const {
+  for (const TypeAttributeConversionCallbackFn &fn :
        llvm::reverse(typeAttributeConversions)) {
     AttributeConversionResult res = fn(type, attr);
     if (res.hasResult())