diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -105,10 +105,15 @@ /// functionality, i.e. we will traverse if the mapped value also has a mapping. struct ConversionValueMapping { /// Lookup a mapped value within the map. If a mapping for the provided value - /// does not exist then return the provided value. If `desiredType` is - /// non-null, returns the most recently mapped value with that type. If an - /// operand of that type does not exist, defaults to normal behavior. - Value lookupOrDefault(Value from, Type desiredType = nullptr) const; + /// does not exist then return the provided value. + Value lookupOrDefault(Value from) const; + + /// Lookup the latest legal value within the map. If a mapping for the + /// provided value does not exist then return the provided value. If + /// `converter` is non-null, returns the most recently mapped value with the + /// legal type. If an operand of that type does not exist, defaults to normal + /// behavior. + Value lookupLatestLegal(Value from, TypeConverter *converter) const; /// Lookup a mapped value within the map, or return null if a mapping does not /// exist. If a mapping exists, this follows the same behavior of @@ -127,22 +132,24 @@ }; } // end anonymous namespace -Value ConversionValueMapping::lookupOrDefault(Value from, - Type desiredType) const { - // If there was no desired type, simply find the leaf value. - if (!desiredType) { - // If this value had a valid mapping, unmap that value as well in the case - // that it was also replaced. - while (auto mappedValue = mapping.lookupOrNull(from)) - from = mappedValue; - return from; - } +Value ConversionValueMapping::lookupOrDefault(Value from) const { + // If this value had a valid mapping, unmap that value as well in the case + // that it was also replaced. 
+ while (auto mappedValue = mapping.lookupOrNull(from)) + from = mappedValue; + return from; +} - // Otherwise, try to find the deepest value that has the desired type. - Value desiredValue; +Value ConversionValueMapping::lookupLatestLegal( + Value from, TypeConverter *converter) const { + if (!converter) + return lookupOrDefault(from); + + // Otherwise, try to find the deepest value that has the legal type. + Value legalValue; do { - if (from.getType() == desiredType) - desiredValue = from; + if (converter->isLegal(from.getType())) + legalValue = from; Value mappedValue = mapping.lookupOrNull(from); if (!mappedValue) @@ -151,7 +158,7 @@ } while (true); // If the desired value was found use it, otherwise default to the leaf value. - return desiredValue ? desiredValue : from; + return legalValue ? legalValue : from; } Value ConversionValueMapping::lookupOrNull(Value from) const { @@ -1039,22 +1046,41 @@ Value operand = it.value(); Type origType = operand.getType(); - // If a converter was provided, get the desired legal types for this - // operand. - Type desiredType; + Value newOperand = mapping.lookupLatestLegal(operand, converter); + + // Handle the case where the conversion was 1->1 and the new operand type + // isn't legal. + Type newOperandType = newOperand.getType(); if (converter) { - // If there is no legal conversion, fail to match this pattern. - legalTypes.clear(); - if (failed(converter->convertType(origType, legalTypes))) { - return notifyMatchFailure(loc, [=](Diagnostic &diag) { - diag << "unable to convert type for operand #" << it.index() - << ", type was " << origType; - }); + if (!converter->isLegal(newOperandType)) { + legalTypes.clear(); + + // If there is no legal conversion, fail to match this pattern. 
+ if (failed(converter->convertType(origType, legalTypes))) { + return notifyMatchFailure(loc, [=](Diagnostic &diag) { + diag << "unable to convert type for operand #" << it.index() + << ", type was " << origType; + }); + } + // TODO: There currently isn't any mechanism to do 1->N type conversion + // via the PatternRewriter replacement API, so for now we just ignore + // it. + if (legalTypes.size() != 1) { + remapped.push_back(newOperand); + continue; + } + Type desiredType = legalTypes.front(); + newOperand = converter->materializeTargetConversion( + rewriter, loc, desiredType, newOperand); + if (!newOperand) { + return notifyMatchFailure(loc, [=](Diagnostic &diag) { + diag << "unable to materialize a conversion for " + "operand #" + << it.index() << ", from " << newOperandType << " to " + << desiredType; + }); + } } - // TODO: There currently isn't any mechanism to do 1->N type conversion - // via the PatternRewriter replacement API, so for now we just ignore it. - if (legalTypes.size() == 1) - desiredType = legalTypes.front(); } else { // TODO: What we should do here is just set `desiredType` to `origType` // and then handle the necessary type conversions after the conversion @@ -1062,24 +1088,6 @@ // receiving the new operands even if the types change, so we keep the // original behavior here for now until all of the patterns relying on // this get updated. - } - Value newOperand = mapping.lookupOrDefault(operand, desiredType); - - // Handle the case where the conversion was 1->1 and the new operand type - // isn't legal. - Type newOperandType = newOperand.getType(); - if (converter && desiredType && newOperandType != desiredType) { - // Attempt to materialize a conversion for this new value.
- newOperand = converter->materializeTargetConversion( - rewriter, loc, desiredType, newOperand); - if (!newOperand) { - return notifyMatchFailure(loc, [=](Diagnostic &diag) { - diag << "unable to materialize a conversion for " - "operand #" - << it.index() << ", from " << newOperandType << " to " - << desiredType; - }); - } } remapped.push_back(newOperand); }