mlir: allow passing a TypeConverter to applySignatureConversion

Add an optional `TypeConverter *converter` parameter (default: nullptr) to
ConversionPatternRewriter::applySignatureConversion. The argument is forwarded
to ConversionPatternRewriterImpl::applySignatureConversion. When it is non-null,
it is used instead of the rewriter's default type converter when converting the
region's entry block signature, and therefore for any materializations.

diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -479,9 +479,12 @@
   /// Apply a signature conversion to the entry block of the given region. This
   /// replaces the entry block with a new block containing the updated
   /// signature. The new entry block to the region is returned for convenience.
+  ///
+  /// If provided, `converter` will be used for any materializations.
   Block *
   applySignatureConversion(Region *region,
-                           TypeConverter::SignatureConversion &conversion);
+                           TypeConverter::SignatureConversion &conversion,
+                           TypeConverter *converter = nullptr);
 
   /// Convert the types of block arguments within the given region. This
   /// replaces each block with a new block containing the updated signature. The
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -741,10 +741,12 @@
       Block *block, TypeConverter &converter,
       TypeConverter::SignatureConversion *conversion = nullptr);
 
-  /// Apply a signature conversion on the given region.
+  /// Apply a signature conversion on the given region, using `converter` for
+  /// materializations if not null.
   Block *
   applySignatureConversion(Region *region,
-                           TypeConverter::SignatureConversion &conversion);
+                           TypeConverter::SignatureConversion &conversion,
+                           TypeConverter *converter);
 
   /// Convert the types of block arguments within the given region.
   FailureOr<Block *>
@@ -1145,9 +1147,11 @@
 }
 
 Block *ConversionPatternRewriterImpl::applySignatureConversion(
-    Region *region, TypeConverter::SignatureConversion &conversion) {
+    Region *region, TypeConverter::SignatureConversion &conversion,
+    TypeConverter *converter) {
   if (!region->empty()) {
-    return *convertBlockSignature(&region->front(), defaultTypeConverter,
+    return *convertBlockSignature(&region->front(),
+                                  converter ? *converter : defaultTypeConverter,
                                   &conversion);
   }
   return nullptr;
@@ -1335,8 +1339,9 @@
 }
 
 Block *ConversionPatternRewriter::applySignatureConversion(
-    Region *region, TypeConverter::SignatureConversion &conversion) {
-  return impl->applySignatureConversion(region, conversion);
+    Region *region, TypeConverter::SignatureConversion &conversion,
+    TypeConverter *converter) {
+  return impl->applySignatureConversion(region, conversion, converter);
 }
 
 FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(