Review notes (text before the first "diff --git" line is ignored by git apply / patch):
* Adds ConversionValueMapping::lookupLatestLegal(Value, TypeConverter *). With a
  null converter it follows the mapping chain to its leaf; otherwise it returns
  the deepest value along the chain whose type converter->isLegal() accepts,
  falling back to the leaf value when no such value exists.
* Operand remapping now starts from lookupLatestLegal(operand, converter). Only
  when a converter is present and that value's type is not legal does it call
  convertType(origType, legalTypes): failure reports a match failure, a 1->1
  result is materialized via materializeTargetConversion, and a 1->N result
  pushes the (still illegal) value unchanged and continues.
* NOTE(review): the "};" closing the `legalTypes.size() != 1` block is a stray
  empty statement; drop the ";".
* NOTE(review): the context line "// Attempt to materialize a conversion for this
  new value." appears to be left as the only content of the `else` (null
  converter) branch once the old materialization code is removed; it should be
  deleted or moved next to the new materializeTargetConversion call.
* NOTE(review): `do { ... } while (true);` in lookupLatestLegal is equivalent to
  a plain `while (true)` loop; consider the simpler form.
* NOTE(review): this patch's line breaks and leading whitespace appear to have
  been collapsed (hunk lines are not on separate lines), so it will not apply as
  shown; regenerate it from the working tree before applying.
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -110,6 +110,13 @@ /// operand of that type does not exist, defaults to normal behavior. Value lookupOrDefault(Value from, Type desiredType = nullptr) const; + /// Lookup the latest legal value within the map. If a mapping for the + /// provided value does not exist then return the provided value. If + /// `converter` is non-null, returns the most recently mapped value with the + /// legal type. If an operand of that type does not exist, defaults to normal + /// behavior. + Value lookupLatestLegal(Value from, TypeConverter *converter) const; + /// Lookup a mapped value within the map, or return null if a mapping does not /// exist. If a mapping exists, this follows the same behavior of /// `lookupOrDefault`. @@ -154,6 +161,30 @@ return desiredValue ? desiredValue : from; } +Value ConversionValueMapping::lookupLatestLegal( + Value from, TypeConverter *converter) const { + if (!converter) { + while (auto mappedValue = mapping.lookupOrNull(from)) + from = mappedValue; + return from; + } + + // Otherwise, try to find the deepest value that has the legal type. + Value desiredValue; + do { + if (converter->isLegal(from.getType())) + desiredValue = from; + + Value mappedValue = mapping.lookupOrNull(from); + if (!mappedValue) + break; + from = mappedValue; + } while (true); + + // If the desired value was found use it, otherwise default to the leaf value. + return desiredValue ? desiredValue : from; +} + Value ConversionValueMapping::lookupOrNull(Value from) const { Value result = lookupOrDefault(from); return result == from ? nullptr : result; @@ -1039,22 +1070,43 @@ Value operand = it.value(); Type origType = operand.getType(); - // If a converter was provided, get the desired legal types for this - // operand. 
- Type desiredType; + Value newOperand = mapping.lookupLatestLegal(operand, converter); + + // Handle the case where the conversion was 1->1 and the new operand type + // isn't legal. + Type newOperandType = newOperand.getType(); if (converter) { - // If there is no legal conversion, fail to match this pattern. - legalTypes.clear(); - if (failed(converter->convertType(origType, legalTypes))) { - return notifyMatchFailure(loc, [=](Diagnostic &diag) { - diag << "unable to convert type for operand #" << it.index() - << ", type was " << origType; - }); - } - // TODO: There currently isn't any mechanism to do 1->N type conversion - // via the PatternRewriter replacement API, so for now we just ignore it. - if (legalTypes.size() == 1) + if (!converter->isLegal(newOperandType)) { + legalTypes.clear(); + + // Get the desired legal types for this operand. + Type desiredType; + // If there is no legal conversion, fail to match this pattern. + if (failed(converter->convertType(origType, legalTypes))) { + return notifyMatchFailure(loc, [=](Diagnostic &diag) { + diag << "unable to convert type for operand #" << it.index() + << ", type was " << origType; + }); + } + // TODO: There currently isn't any mechanism to do 1->N type conversion + // via the PatternRewriter replacement API, so for now we just ignore + // it. 
+ if (legalTypes.size() != 1) { + remapped.push_back(newOperand); + continue; + }; desiredType = legalTypes.front(); + newOperand = converter->materializeTargetConversion( + rewriter, loc, desiredType, newOperand); + if (!newOperand) { + return notifyMatchFailure(loc, [=](Diagnostic &diag) { + diag << "unable to materialize a conversion for " + "operand #" + << it.index() << ", from " << newOperandType << " to " + << desiredType; + }); + } + } } else { // TODO: What we should do here is just set `desiredType` to `origType` // and then handle the necessary type conversions after the conversion @@ -1062,24 +1114,7 @@ // receiving the new operands even if the types change, so we keep the // original behavior here for now until all of the patterns relying on // this get updated. - } - Value newOperand = mapping.lookupOrDefault(operand, desiredType); - - // Handle the case where the conversion was 1->1 and the new operand type - // isn't legal. - Type newOperandType = newOperand.getType(); - if (converter && desiredType && newOperandType != desiredType) { // Attempt to materialize a conversion for this new value. - newOperand = converter->materializeTargetConversion( - rewriter, loc, desiredType, newOperand); - if (!newOperand) { - return notifyMatchFailure(loc, [=](Diagnostic &diag) { - diag << "unable to materialize a conversion for " - "operand #" - << it.index() << ", from " << newOperandType << " to " - << desiredType; - }); - } } remapped.push_back(newOperand); }