diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -113,20 +113,40 @@ /// Register a materialization function, which must be convertible to the /// following form: - /// `Optional(PatternRewriter &, T, ValueRange, Location)`, + /// `Optional(OpBuilder &, T, ValueRange, Location)`, /// where `T` is any subclass of `Type`. This function is responsible for - /// creating an operation, using the PatternRewriter and Location provided, - /// that "casts" a range of values into a single value of the given type `T`. - /// It must return a Value of the converted type on success, an `llvm::None` - /// if it failed but other materialization can be attempted, and `nullptr` on + /// creating an operation, using the OpBuilder and Location provided, that + /// "casts" a range of values into a single value of the given type `T`. It + /// must return a Value of the converted type on success, an `llvm::None` if + /// it failed but other materialization can be attempted, and `nullptr` on /// unrecoverable failure. It will only be called for (sub)types of `T`. /// Materialization functions must be provided when a type conversion /// results in more than one type, or if a type conversion may persist after /// the conversion has finished. + /// + /// This method registers a materialization that will be called when + /// converting an illegal block argument type, to a legal type. template ::template arg_t<1>> - void addMaterialization(FnT &&callback) { - registerMaterialization( + void addArgumentMaterialization(FnT &&callback) { + argumentMaterializations.emplace_back( + wrapMaterialization(std::forward(callback))); + } + /// This method registers a materialization that will be called when + /// converting a legal type to an illegal source type. 
This is used when + /// conversions to an illegal type must persist beyond the main conversion. + template ::template arg_t<1>> + void addSourceMaterialization(FnT &&callback) { + sourceMaterializations.emplace_back( + wrapMaterialization(std::forward(callback))); + } + /// This method registers a materialization that will be called when + /// converting type from an illegal, or source, type to a legal type. + template ::template arg_t<1>> + void addTargetMaterialization(FnT &&callback) { + targetMaterializations.emplace_back( wrapMaterialization(std::forward(callback))); } @@ -182,9 +202,24 @@ Optional convertBlockSignature(Block *block); /// Materialize a conversion from a set of types into one result type by - /// generating a cast operation of some kind. - Value materializeConversion(PatternRewriter &rewriter, Location loc, - Type resultType, ValueRange inputs); + /// generating a cast sequence of some kind. See the respective + /// `add*Materialization` for more information on the context for these + /// methods. + Value materializeArgumentConversion(OpBuilder &builder, Location loc, + Type resultType, ValueRange inputs) { + return materializeConversion(argumentMaterializations, builder, loc, + resultType, inputs); + } + Value materializeSourceConversion(OpBuilder &builder, Location loc, + Type resultType, ValueRange inputs) { + return materializeConversion(sourceMaterializations, builder, loc, + resultType, inputs); + } + Value materializeTargetConversion(OpBuilder &builder, Location loc, + Type resultType, ValueRange inputs) { + return materializeConversion(targetMaterializations, builder, loc, + resultType, inputs); + } private: /// The signature of the callback used to convert a type. 
If the new set of @@ -193,8 +228,15 @@ using ConversionCallbackFn = std::function(Type, SmallVectorImpl &)>; - using MaterializationCallbackFn = std::function( - PatternRewriter &, Type, ValueRange, Location)>; + /// The signature of the callback used to materialize a conversion. + using MaterializationCallbackFn = + std::function(OpBuilder &, Type, ValueRange, Location)>; + + /// Attempt to materialize a conversion using one of the provided + /// materialization functions. + Value materializeConversion( + MutableArrayRef materializations, + OpBuilder &builder, Location loc, Type resultType, ValueRange inputs); /// Generate a wrapper for the given callback. This allows for accepting /// different callback forms, that all compose into a single version. @@ -240,24 +282,21 @@ template MaterializationCallbackFn wrapMaterialization(FnT &&callback) { return [callback = std::forward(callback)]( - PatternRewriter &rewriter, Type resultType, ValueRange inputs, + OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) -> Optional { if (T derivedType = resultType.dyn_cast()) - return callback(rewriter, derivedType, inputs, loc); + return callback(builder, derivedType, inputs, loc); return llvm::None; }; } - /// Register a materialization. - void registerMaterialization(MaterializationCallbackFn &&callback) { - materializations.emplace_back(std::move(callback)); - } - /// The set of registered conversion functions. SmallVector conversions; /// The list of registered materialization functions. - SmallVector materializations; + SmallVector argumentMaterializations; + SmallVector sourceMaterializations; + SmallVector targetMaterializations; /// A set of cached conversions to avoid recomputing in the common case. /// Direct 1-1 conversions are the most common, so this cache stores the @@ -325,7 +364,7 @@ protected: /// An optional type converter for use by this pattern. 
- TypeConverter *typeConverter; + TypeConverter *typeConverter = nullptr; private: using RewritePattern::rewrite; diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -154,19 +154,42 @@ // Materialization for memrefs creates descriptor structs from individual // values constituting them, when descriptors are used, i.e. more than one // value represents a memref. - addMaterialization([&](PatternRewriter &rewriter, - UnrankedMemRefType resultType, ValueRange inputs, - Location loc) -> Optional { + addArgumentMaterialization( + [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs, + Location loc) -> Optional { + if (inputs.size() == 1) + return llvm::None; + return UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, + inputs); + }); + addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType, + ValueRange inputs, + Location loc) -> Optional { if (inputs.size() == 1) return llvm::None; - return UnrankedMemRefDescriptor::pack(rewriter, loc, *this, resultType, - inputs); + return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs); }); - addMaterialization([&](PatternRewriter &rewriter, MemRefType resultType, - ValueRange inputs, Location loc) -> Optional { - if (inputs.size() == 1) + // Add generic source and target materializations to handle cases where + // non-LLVM types persist after an LLVM conversion. + addSourceMaterialization([&](OpBuilder &builder, Type resultType, + ValueRange inputs, + Location loc) -> Optional { + if (inputs.size() != 1) + return llvm::None; + // FIXME: These should check LLVM::DialectCastOp can actually be constructed + // from the input and result. 
+ return builder.create(loc, resultType, inputs[0]) + .getResult(); + }); + addTargetMaterialization([&](OpBuilder &builder, Type resultType, + ValueRange inputs, + Location loc) -> Optional { + if (inputs.size() != 1) return llvm::None; - return MemRefDescriptor::pack(rewriter, loc, *this, resultType, inputs); + // FIXME: These should check LLVM::DialectCastOp can actually be constructed + // from the input and result. + return builder.create(loc, resultType, inputs[0]) + .getResult(); }); } diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -222,6 +222,16 @@ spirv::TargetEnv targetEnv(spirv::lookupTargetEnv(module)); SPIRVTypeConverter typeConverter(targetEnv); + + // Insert a bitcast in the case of a pointer type change. + typeConverter.addSourceMaterialization([](OpBuilder &builder, + spirv::PointerType type, + ValueRange inputs, Location loc) { + if (inputs.size() != 1 || !inputs[0].getType().isa()) + return Value(); + return builder.create(loc, type, inputs[0]).getResult(); + }); + OwningRewritePatternList patterns; patterns.insert(context, typeConverter); diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -77,7 +77,11 @@ Location Value::getLoc() const { if (auto *op = getDefiningOp()) return op->getLoc(); - return UnknownLoc::get(getContext()); + + // Use the location of the parent operation if this is a block argument. + // TODO: Should we just add locations to block arguments? + Operation *parentOp = cast().getOwner()->getParentOp(); + return parentOp ? parentOp->getLoc() : UnknownLoc::get(getContext()); } /// Return the Region in which this Value is defined. 
diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -17,6 +17,7 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/ScopedPrinter.h" using namespace mlir; @@ -106,8 +107,15 @@ /// functionality, i.e. we will traverse if the mapped value also has a mapping. struct ConversionValueMapping { /// Lookup a mapped value within the map. If a mapping for the provided value - /// does not exist then return the provided value. - Value lookupOrDefault(Value from) const; + /// does not exist then return the provided value. If `desiredType` is + /// non-null, returns the most recently mapped value with that type. If an + /// operand of that type does not exist, defaults to normal behavior. + Value lookupOrDefault(Value from, Type desiredType = nullptr) const; + + /// Lookup a mapped value within the map, or return null if a mapping does not + /// exist. If a mapping exists, this follows the same behavior of + /// `lookupOrDefault`. + Value lookupOrNull(Value from) const; /// Map a value to the one provided. void map(Value oldVal, Value newVal) { mapping.map(oldVal, newVal); } @@ -121,14 +129,36 @@ }; } // end anonymous namespace -/// Lookup a mapped value within the map. If a mapping for the provided value -/// does not exist then return the provided value. -Value ConversionValueMapping::lookupOrDefault(Value from) const { - // If this value had a valid mapping, unmap that value as well in the case - // that it was also replaced. - while (auto mappedValue = mapping.lookupOrNull(from)) +Value ConversionValueMapping::lookupOrDefault(Value from, + Type desiredType) const { + // If there was no desired type, simply find the leaf value. 
+ if (!desiredType) { + // If this value had a valid mapping, unmap that value as well in the case + // that it was also replaced. + while (auto mappedValue = mapping.lookupOrNull(from)) + from = mappedValue; + return from; + } + + // Otherwise, try to find the deepest value that has the desired type. + Value desiredValue; + do { + if (from.getType() == desiredType) + desiredValue = from; + + Value mappedValue = mapping.lookupOrNull(from); + if (!mappedValue) + break; from = mappedValue; - return from; + } while (true); + + // If the desired value was found use it, otherwise default to the leaf value. + return desiredValue ? desiredValue : from; +} + +Value ConversionValueMapping::lookupOrNull(Value from) const { + Value result = lookupOrDefault(from); + return result == from ? nullptr : result; } //===----------------------------------------------------------------------===// @@ -209,10 +239,17 @@ /// its original state. void discardRewrites(Block *block); - /// Fully replace uses of the old arguments with the new, materializing cast - /// operations as necessary. + /// Fully replace uses of the old arguments with the new. void applyRewrites(ConversionValueMapping &mapping); + /// Materialize any necessary conversions for converted arguments that have + /// live users, using the provided `findLiveUser` to search for a user that + /// survives the conversion process. + LogicalResult + materializeLiveConversions(ConversionValueMapping &mapping, + OpBuilder &builder, + function_ref findLiveUser); + //===--------------------------------------------------------------------===// // Conversion //===--------------------------------------------------------------------===// @@ -307,7 +344,6 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) { for (auto &info : conversionInfo) { - Block *newBlock = info.first; ConvertedBlockInfo &blockInfo = info.second; Block *origBlock = blockInfo.origBlock; @@ -318,24 +354,8 @@ // Handle the case of a 1->0 value mapping. 
if (!argInfo) { - // If a replacement value was given for this argument, use that to - // replace all uses. - auto argReplacementValue = mapping.lookupOrDefault(origArg); - if (argReplacementValue != origArg) { - origArg.replaceAllUsesWith(argReplacementValue); - continue; - } - // If there are any dangling uses then replace the argument with one - // generated by the type converter. This is necessary as the cast must - // persist in the IR after conversion. - if (!origArg.use_empty()) { - rewriter.setInsertionPointToStart(newBlock); - Value newArg = blockInfo.converter->materializeConversion( - rewriter, origArg.getLoc(), origArg.getType(), llvm::None); - assert(newArg && - "Couldn't materialize a block argument after 1->0 conversion"); + if (Value newArg = mapping.lookupOrNull(origArg)) origArg.replaceAllUsesWith(newArg); - } continue; } @@ -355,6 +375,59 @@ } } +LogicalResult ArgConverter::materializeLiveConversions( + ConversionValueMapping &mapping, OpBuilder &builder, + function_ref findLiveUser) { + for (auto &info : conversionInfo) { + Block *newBlock = info.first; + ConvertedBlockInfo &blockInfo = info.second; + Block *origBlock = blockInfo.origBlock; + + // Process the remapping for each of the original arguments. + for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) { + // FIXME: We should run the below checks even if the type conversion was + // 1->N, but a lot of existing lowering rely on the block argument being + // blindly replaced. Those usages should be updated, and this if should be + // removed. + if (blockInfo.argInfo[i]) + continue; + + // If the type of this argument changed and the argument is still live, we + // need to materialize a conversion. 
+ BlockArgument origArg = origBlock->getArgument(i); + auto argReplacementValue = mapping.lookupOrDefault(origArg); + bool isDroppedArg = argReplacementValue == origArg; + if (argReplacementValue.getType() == origArg.getType() && !isDroppedArg) + continue; + Operation *liveUser = findLiveUser(origArg); + if (!liveUser) + continue; + + if (OpResult result = argReplacementValue.dyn_cast()) + rewriter.setInsertionPointAfter(result.getOwner()); + else + rewriter.setInsertionPointToStart(newBlock); + Value newArg = blockInfo.converter->materializeSourceConversion( + rewriter, origArg.getLoc(), origArg.getType(), + isDroppedArg ? ValueRange() : ValueRange(argReplacementValue)); + if (!newArg) { + InFlightDiagnostic diag = + emitError(origArg.getLoc()) + << "failed to materialize conversion for block argument #" << i + << " that remained live after conversion, type was " + << origArg.getType(); + if (!isDroppedArg) + diag << ", with target type " << argReplacementValue.getType(); + diag.attachNote(liveUser->getLoc()) + << "see existing live user here: " << *liveUser; + return failure(); + } + mapping.map(origArg, newArg); + } + } + return success(); +} + //===----------------------------------------------------------------------===// // Conversion @@ -417,8 +490,8 @@ // to pack the new values. For 1->1 mappings, if there is no materialization // provided, use the argument directly instead. auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size); - Value newArg = converter.materializeConversion(rewriter, origArg.getLoc(), - origArg.getType(), replArgs); + Value newArg = converter.materializeArgumentConversion( + rewriter, origArg.getLoc(), origArg.getType(), replArgs); if (!newArg) { assert(replArgs.size() == 1 && "couldn't materialize the result of 1->N conversion"); @@ -516,13 +589,15 @@ SmallVector successors; }; -/// This class represents one requested operation replacement via 'replaceOp'. 
+/// This class represents one requested operation replacement via 'replaceOp' or +/// 'eraseOp`. struct OpReplacement { OpReplacement() = default; - OpReplacement(ValueRange newValues) - : newValues(newValues.begin(), newValues.end()) {} + OpReplacement(TypeConverter *converter) : converter(converter) {} - SmallVector newValues; + /// An optional type converter that can be used to materialize conversions + /// between the new and old values if necessary. + TypeConverter *converter = nullptr; }; /// The kind of the block action performed during the rewrite. Actions can be @@ -611,9 +686,14 @@ /// "numActionsToKeep" actions remains. void undoBlockActions(unsigned numActionsToKeep = 0); - /// Remap the given operands to those with potentially different types. - void remapValues(Operation::operand_range operands, - SmallVectorImpl &remapped); + /// Remap the given operands to those with potentially different types. The + /// provided type converter is used to ensure that the remapped types are + /// legal. Returns success if the operands could be remapped, failure + /// otherwise. + LogicalResult remapValues(Location loc, PatternRewriter &rewriter, + TypeConverter *converter, + Operation::operand_range operands, + SmallVectorImpl &remapped); /// Returns true if the given operation is ignored, and does not need to be /// converted. @@ -666,6 +746,11 @@ void notifyRegionWasClonedBefore(iterator_range &blocks, Location origRegionLoc); + /// Notifies that a pattern match failed for the given reason. + LogicalResult + notifyMatchFailure(Location loc, + function_ref reasonCallback); + //===--------------------------------------------------------------------===// // State //===--------------------------------------------------------------------===// @@ -712,6 +797,10 @@ /// explicitly provided. TypeConverter defaultTypeConverter; + /// The current conversion pattern that is being rewritten, or nullptr if + /// called from outside of a conversion pattern rewrite. 
+ const ConversionPattern *currentConversionPattern = nullptr; + #ifndef NDEBUG /// A set of operations that have pending updates. This tracking isn't /// strictly necessary, and is thus only active during debug builds for extra @@ -759,11 +848,9 @@ void ConversionPatternRewriterImpl::applyRewrites() { // Apply all of the rewrites replacements requested during conversion. for (auto &repl : replacements) { - for (unsigned i = 0, e = repl.second.newValues.size(); i != e; ++i) { - if (auto newValue = repl.second.newValues[i]) - repl.first->getResult(i).replaceAllUsesWith( - mapping.lookupOrDefault(newValue)); - } + for (OpResult result : repl.first->getResults()) + if (Value newValue = mapping.lookupOrNull(result)) + result.replaceAllUsesWith(newValue); // If this operation defines any regions, drop any pending argument // rewrites. @@ -905,11 +992,61 @@ blockActions.resize(numActionsToKeep); } -void ConversionPatternRewriterImpl::remapValues( +LogicalResult ConversionPatternRewriterImpl::remapValues( + Location loc, PatternRewriter &rewriter, TypeConverter *converter, Operation::operand_range operands, SmallVectorImpl &remapped) { remapped.reserve(llvm::size(operands)); - for (Value operand : operands) - remapped.push_back(mapping.lookupOrDefault(operand)); + + SmallVector legalTypes; + for (auto it : llvm::enumerate(operands)) { + Value operand = it.value(); + Type origType = operand.getType(); + + // If a converter was provided, get the desired legal types for this + // operand. + Type desiredType; + if (converter) { + // If there is no legal conversion, fail to match this pattern. 
+ legalTypes.clear(); + if (failed(converter->convertType(origType, legalTypes))) { + return notifyMatchFailure(loc, [=](Diagnostic &diag) { + diag << "unable to convert type for operand #" << it.index() + << ", type was " << origType; + }); + } + // TODO: There currently isn't any mechanism to do 1->N type conversion + // via the PatternRewriter replacement API, so for now we just ignore it. + if (legalTypes.size() == 1) + desiredType = legalTypes.front(); + } else { + // TODO: What we should do here is just set `desiredType` to `origType` + // and then handle the necessary type conversions after the conversion + // process has finished. Unfortunately a lot of patterns currently rely on + // receiving the new operands even if the types change, so we keep the + // original behavior here for now until all of the patterns relying on + // this get updated. + } + Value newOperand = mapping.lookupOrDefault(operand, desiredType); + + // Handle the case where the conversion was 1->1 and the new operand type + // isn't legal. + Type newOperandType = newOperand.getType(); + if (converter && desiredType && newOperandType != desiredType) { + // Attempt to materialize a conversion for this new value. + newOperand = converter->materializeTargetConversion( + rewriter, loc, desiredType, newOperand); + if (!newOperand) { + return notifyMatchFailure(loc, [=](Diagnostic &diag) { + diag << "unable to materialize a conversion for " + "operand #" + << it.index() << ", from " << newOperandType << " to " + << desiredType; + }); + } + } + remapped.push_back(newOperand); + } + return success(); } bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const { @@ -987,16 +1124,22 @@ Value newValue, result; for (auto it : llvm::zip(newValues, op->getResults())) { std::tie(newValue, result) = it; - if (!newValue) + if (!newValue) { resultChanged = true; - else - mapping.map(result, newValue); + continue; + } + // Remap, and check for any result type changes. 
+ mapping.map(result, newValue); + resultChanged |= (newValue.getType() != result.getType()); } if (resultChanged) operationsWithChangedResults.push_back(replacements.size()); // Record the requested operation replacement. - replacements.insert(std::make_pair(op, OpReplacement(newValues))); + TypeConverter *converter = nullptr; + if (currentConversionPattern) + converter = currentConversionPattern->getTypeConverter(); + replacements.insert(std::make_pair(op, OpReplacement(converter))); // Mark this operation as recursively ignored so that we don't need to // convert any nested operations. @@ -1041,6 +1184,16 @@ assert(succeeded(result) && "expected region to have no unreachable blocks"); } +LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure( + Location loc, function_ref reasonCallback) { + LLVM_DEBUG({ + Diagnostic diag(loc, DiagnosticSeverity::Remark); + reasonCallback(diag); + logger.startLine() << "** Failure : " << diag.str() << "\n"; + }); + return failure(); +} + //===----------------------------------------------------------------------===// // ConversionPatternRewriter //===----------------------------------------------------------------------===// @@ -1200,12 +1353,7 @@ /// PatternRewriter hook for notifying match failure reasons. LogicalResult ConversionPatternRewriter::notifyMatchFailure( Operation *op, function_ref reasonCallback) { - LLVM_DEBUG({ - Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark); - reasonCallback(diag); - impl->logger.startLine() << "** Failure : " << diag.str() << "\n"; - }); - return failure(); + return impl->notifyMatchFailure(op->getLoc(), reasonCallback); } /// Return a reference to the internal implementation. 
@@ -1221,9 +1369,22 @@ LogicalResult ConversionPattern::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { - SmallVector operands; auto &dialectRewriter = static_cast(rewriter); - dialectRewriter.getImpl().remapValues(op->getOperands(), operands); + auto &rewriterImpl = dialectRewriter.getImpl(); + + // Track the current conversion pattern in the rewriter. + assert(!rewriterImpl.currentConversionPattern && + "already inside of a pattern rewrite"); + llvm::SaveAndRestore currentPatternGuard( + rewriterImpl.currentConversionPattern, this); + + // Remap the operands of the operation. + SmallVector operands; + if (failed(rewriterImpl.remapValues(op->getLoc(), rewriter, + getTypeConverter(), op->getOperands(), + operands))) { + return failure(); + } return matchAndRewrite(op, operands, dialectRewriter); } @@ -1878,6 +2039,24 @@ /// remaining artifacts and complete the conversion. LogicalResult finalize(ConversionPatternRewriter &rewriter); + /// Legalize the types of converted block arguments. + LogicalResult + legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter, + ConversionPatternRewriterImpl &rewriterImpl); + + /// Legalize an operation result that was marked as "erased". + LogicalResult + legalizeErasedResult(Operation *op, OpResult result, + ConversionPatternRewriterImpl &rewriterImpl); + + /// Legalize an operation result that was replaced with a value of a different + /// type. + LogicalResult + legalizeChangedResultType(Operation *op, OpResult result, Value newValue, + TypeConverter *replConverter, + ConversionPatternRewriter &rewriter, + ConversionPatternRewriterImpl &rewriterImpl); + /// The legalizer to use when converting operations. 
OperationLegalizer opLegalizer; @@ -1961,33 +2140,145 @@ LogicalResult OperationConverter::finalize(ConversionPatternRewriter &rewriter) { ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl(); - auto isOpDead = [&](Operation *op) { return rewriterImpl.isOpIgnored(op); }; - // Process the operations with changed results. - for (unsigned replIdx : rewriterImpl.operationsWithChangedResults) { + // Legalize converted block arguments. + if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl))) + return failure(); + + // Process requested operation replacements. + for (unsigned i = 0, e = rewriterImpl.operationsWithChangedResults.size(); + i != e; ++i) { + unsigned replIdx = rewriterImpl.operationsWithChangedResults[i]; auto &repl = *(rewriterImpl.replacements.begin() + replIdx); - for (auto it : llvm::zip(repl.first->getResults(), repl.second.newValues)) { - Value result = std::get<0>(it), newValue = std::get<1>(it); + for (OpResult result : repl.first->getResults()) { + Value newValue = rewriterImpl.mapping.lookupOrNull(result); // If the operation result was replaced with null, all of the uses of this // value should be replaced. - if (newValue) + if (!newValue) { + if (failed(legalizeErasedResult(repl.first, result, rewriterImpl))) + return failure(); + continue; + } + + // Otherwise, check to see if the type of the result changed. + if (result.getType() == newValue.getType()) continue; - auto liveUserIt = llvm::find_if_not(result.getUsers(), isOpDead); - if (liveUserIt != result.user_end()) { - InFlightDiagnostic diag = repl.first->emitError() - << "failed to legalize operation '" - << repl.first->getName() - << "' marked as erased"; - diag.attachNote(liveUserIt->getLoc()) - << "found live user of result #" - << result.cast().getResultNumber() << ": " << *liveUserIt; + // Legalize this result. 
+ rewriter.setInsertionPoint(repl.first); + if (failed(legalizeChangedResultType(repl.first, result, newValue, + repl.second.converter, rewriter, + rewriterImpl))) return failure(); - } + + // Update the end iterator for this loop in the case it was updated + // when legalizing generated conversion operations. + e = rewriterImpl.operationsWithChangedResults.size(); + } + } + return success(); +} + +LogicalResult OperationConverter::legalizeConvertedArgumentTypes( + ConversionPatternRewriter &rewriter, + ConversionPatternRewriterImpl &rewriterImpl) { + // Functor used to check if all users of a value will be dead after + // conversion. + auto findLiveUser = [&](Value val) { + auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) { + return rewriterImpl.isOpIgnored(user); + }); + return liveUserIt == val.user_end() ? nullptr : *liveUserIt; + }; + + // Materialize any necessary conversions for converted block arguments that + // are still live. + size_t numCreatedOps = rewriterImpl.createdOps.size(); + if (failed(rewriterImpl.argConverter.materializeLiveConversions( + rewriterImpl.mapping, rewriter, findLiveUser))) + return failure(); + + // Legalize any newly created operatoins during argument materialization. + for (int i : llvm::seq(numCreatedOps, rewriterImpl.createdOps.size())) { + if (failed(opLegalizer.legalize(rewriterImpl.createdOps[i], rewriter))) { + return rewriterImpl.createdOps[i]->emitError() + << "failed to legalize conversion operation generated for block " + "argument that remained live after conversion"; + } + } + return success(); +} + +LogicalResult OperationConverter::legalizeErasedResult( + Operation *op, OpResult result, + ConversionPatternRewriterImpl &rewriterImpl) { + // If the operation result was replaced with null, all of the uses of this + // value should be replaced. 
+ auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) { + return rewriterImpl.isOpIgnored(user); + }); + if (liveUserIt != result.user_end()) { + InFlightDiagnostic diag = op->emitError("failed to legalize operation '") + << op->getName() << "' marked as erased"; + diag.attachNote(liveUserIt->getLoc()) + << "found live user of result #" << result.getResultNumber() << ": " + << *liveUserIt; + return failure(); + } + return success(); +} + +LogicalResult OperationConverter::legalizeChangedResultType( + Operation *op, OpResult result, Value newValue, + TypeConverter *replConverter, ConversionPatternRewriter &rewriter, + ConversionPatternRewriterImpl &rewriterImpl) { + // Walk the users of this value to see if there are any live users that + // weren't replaced during conversion. + auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) { + return rewriterImpl.isOpIgnored(user); + }); + if (liveUserIt == result.user_end()) + return success(); + + // A conversion back to the original type can only be materialized when + // the replacement has a type converter; otherwise forward the new value. + if (!replConverter) { + // TODO: We should emit an error here, similarly to the case where the + // result is replaced with null. Unfortunately a lot of existing + // patterns rely on this behavior, so until those patterns are updated + // we keep the legacy behavior here of just forwarding the new value. + return success(); + } + + // Track the number of created operations so that new ones can be legalized. + size_t numCreatedOps = rewriterImpl.createdOps.size(); + + // Materialize a conversion for this live result value.
+ Type resultType = result.getType(); + Value convertedValue = replConverter->materializeSourceConversion( + rewriter, op->getLoc(), resultType, newValue); + if (!convertedValue) { + InFlightDiagnostic diag = op->emitError() + << "failed to materialize conversion for result #" + << result.getResultNumber() << " of operation '" + << op->getName() + << "' that remained live after conversion"; + diag.attachNote(liveUserIt->getLoc()) + << "see existing live user here: " << *liveUserIt; + return failure(); + } + + // Legalize all of the newly created conversion operations. + for (int i : llvm::seq(numCreatedOps, rewriterImpl.createdOps.size())) { + if (failed(opLegalizer.legalize(rewriterImpl.createdOps[i], rewriter))) { + return op->emitError("failed to legalize conversion operation generated ") + << "for result #" << result.getResultNumber() << " of operation '" + << op->getName() << "' that remained live after conversion"; } } + rewriterImpl.mapping.map(result, convertedValue); return success(); } @@ -2136,11 +2427,11 @@ return success(); } -Value TypeConverter::materializeConversion(PatternRewriter &rewriter, - Location loc, Type resultType, - ValueRange inputs) { +Value TypeConverter::materializeConversion( + MutableArrayRef materializations, + OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) { for (MaterializationCallbackFn &fn : llvm::reverse(materializations)) - if (Optional result = fn(rewriter, resultType, inputs, loc)) + if (Optional result = fn(builder, resultType, inputs, loc)) return result.getValue(); return nullptr; } diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir --- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir +++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir @@ -75,15 +75,3 @@ %0 = rsqrt %arg0 : vector<4x3xf32> std.return } - -// ----- - -// This should not crash.
The first operation cannot be converted, so the -// second should not match. This attempts to convert `return` to `llvm.return` -// and complains about non-LLVM types. -func @unknown_source() -> i32 { - %0 = "foo"() : () -> i32 - %1 = addi %0, %0 : i32 - // expected-error@+1 {{must be LLVM dialect type}} - return %1 : i32 -} diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir @@ -57,9 +57,12 @@ // CHECK: [[CONST3:%.*]] = spv.constant 0 : i32 // CHECK: [[ARG3PTR:%.*]] = spv.AccessChain [[ADDRESSARG3]]{{\[}}[[CONST3]] // CHECK: [[ARG3:%.*]] = spv.Load "StorageBuffer" [[ARG3PTR]] - // CHECK: [[ARG2:%.*]] = spv._address_of [[VAR2]] - // CHECK: [[ARG1:%.*]] = spv._address_of [[VAR1]] - // CHECK: [[ARG0:%.*]] = spv._address_of [[VAR0]] + // CHECK: [[ADDRESSARG2:%.*]] = spv._address_of [[VAR2]] + // CHECK: [[ARG2:%.*]] = spv.Bitcast [[ADDRESSARG2]] + // CHECK: [[ADDRESSARG1:%.*]] = spv._address_of [[VAR1]] + // CHECK: [[ARG1:%.*]] = spv.Bitcast [[ADDRESSARG1]] + // CHECK: [[ADDRESSARG0:%.*]] = spv._address_of [[VAR0]] + // CHECK: [[ARG0:%.*]] = spv.Bitcast [[ADDRESSARG0]] %0 = spv._address_of @__builtin_var_WorkgroupId__ : !spv.ptr, Input> %1 = spv.Load "Input" %0 : vector<3xi32> %2 = spv.CompositeExtract %1[0 : i32] : vector<3xi32> diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir @@ -0,0 +1,64 @@ +// RUN: mlir-opt %s -test-legalize-type-conversion -allow-unregistered-dialect -split-input-file -verify-diagnostics | FileCheck %s + +// expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion, type was 'i16'}} +func @test_invalid_arg_materialization(%arg0:
i16) { + // expected-note@below {{see existing live user here}} + "foo.return"(%arg0) : (i16) -> () +} + +// ----- + +// expected-error@below {{failed to legalize conversion operation generated for block argument}} +func @test_invalid_arg_illegal_materialization(%arg0: i32) { + "foo.return"(%arg0) : (i32) -> () +} + +// ----- + +// CHECK-LABEL: func @test_valid_arg_materialization +func @test_valid_arg_materialization(%arg0: i64) { + // CHECK: %[[ARG:.*]] = "test.type_producer" + // CHECK: "foo.return"(%[[ARG]]) : (i64) + + "foo.return"(%arg0) : (i64) -> () +} + +// ----- + +func @test_invalid_result_materialization() { + // expected-error@below {{failed to materialize conversion for result #0 of operation 'test.type_producer' that remained live after conversion}} + %result = "test.type_producer"() : () -> f16 + + // expected-note@below {{see existing live user here}} + "foo.return"(%result) : (f16) -> () +} + +// ----- + +func @test_invalid_result_materialization() { + // expected-error@below {{failed to materialize conversion for result #0 of operation 'test.type_producer' that remained live after conversion}} + %result = "test.type_producer"() : () -> f16 + + // expected-note@below {{see existing live user here}} + "foo.return"(%result) : (f16) -> () +} + +// ----- + +func @test_invalid_result_legalization() { + // expected-error@below {{failed to legalize conversion operation generated for result #0 of operation 'test.type_producer' that remained live after conversion}} + %result = "test.type_producer"() : () -> i16 + "foo.return"(%result) : (i16) -> () +} + +// ----- + +// CHECK-LABEL: func @test_valid_result_legalization +func @test_valid_result_legalization() { + // CHECK: %[[RESULT:.*]] = "test.type_producer"() : () -> f64 + // CHECK: %[[CAST:.*]] = "test.cast"(%[[RESULT]]) : (f64) -> f32 + // CHECK: "foo.return"(%[[CAST]]) : (f32) + + %result = "test.type_producer"() : () -> f32 + "foo.return"(%result) : (f32) -> () +} diff --git
a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -485,8 +485,9 @@ using TypeConverter::TypeConverter; TestTypeConverter() { addConversion(convertType); - addMaterialization(materializeCast); - addMaterialization(materializeOneToOneCast); + addArgumentMaterialization(materializeCast); + addArgumentMaterialization(materializeOneToOneCast); + addSourceMaterialization(materializeCast); } static LogicalResult convertType(Type t, SmallVectorImpl &results) { @@ -519,21 +520,21 @@ /// Hook for materializing a conversion. This is necessary because we generate /// 1->N type mappings. - static Optional materializeCast(PatternRewriter &rewriter, + static Optional materializeCast(OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) { if (inputs.size() == 1) return inputs[0]; - return rewriter.create(loc, resultType, inputs).getResult(); + return builder.create(loc, resultType, inputs).getResult(); } /// Materialize the cast for one-to-one conversion from i64 to f64.
- static Optional materializeOneToOneCast(PatternRewriter &rewriter, + static Optional materializeOneToOneCast(OpBuilder &builder, IntegerType resultType, ValueRange inputs, Location loc) { if (resultType.getWidth() == 42 && inputs.size() == 1) - return rewriter.create(loc, resultType, inputs).getResult(); + return builder.create(loc, resultType, inputs).getResult(); return llvm::None; } }; @@ -742,6 +743,102 @@ }; } // end anonymous namespace +//===----------------------------------------------------------------------===// +// Test type conversions +//===----------------------------------------------------------------------===// + +namespace { +struct TestTypeConversionProducer + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(TestTypeProducerOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + Type resultType = op.getType(); + if (resultType.isa()) + resultType = rewriter.getF64Type(); + else if (resultType.isInteger(16)) + resultType = rewriter.getIntegerType(64); + else + return failure(); + + rewriter.replaceOpWithNewOp(op, resultType); + return success(); + } +}; + +struct TestTypeConversionDriver + : public PassWrapper> { + void runOnOperation() override { + // Initialize the type converter. + TypeConverter converter; + + // Add the legal set of type conversions. + converter.addConversion([](Type type) -> Type { + // Treat F64 as legal. + if (type.isF64()) + return type; + // Allow converting BF16/F16/F32 to F64. + if (type.isBF16() || type.isF16() || type.isF32()) + return FloatType::getF64(type.getContext()); + // Otherwise, the type is illegal. + return nullptr; + }); + converter.addConversion([](IntegerType type, SmallVectorImpl &) { + // Drop all integer types. + return success(); + }); + + // Add the legal set of type materializations.
+ converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, + ValueRange inputs, + Location loc) -> Value { + // Allow casting an F64 value back to any result type except F16. + if (!resultType.isF16() && inputs.size() == 1 && + inputs[0].getType().isF64()) + return builder.create(loc, resultType, inputs).getResult(); + // Allow producing an i32 or i64 from nothing. + if ((resultType.isInteger(32) || resultType.isInteger(64)) && + inputs.empty()) + return builder.create(loc, resultType); + // Allow producing an integer of any width from a single integer value. + if (resultType.isa() && inputs.size() == 1 && + inputs[0].getType().isa()) + return builder.create(loc, resultType, inputs).getResult(); + // Otherwise, fail. + return nullptr; + }); + + // Initialize the conversion target. + mlir::ConversionTarget target(getContext()); + target.addDynamicallyLegalOp([](TestTypeProducerOp op) { + return op.getType().isF64() || op.getType().isInteger(64); + }); + target.addDynamicallyLegalOp([&](FuncOp op) { + return converter.isSignatureLegal(op.getType()) && + converter.isLegal(&op.getBody()); + }); + target.addDynamicallyLegalOp([&](TestCastOp op) { + // Allow casts from F64 to F32. + return (*op.operand_type_begin()).isF64() && op.getType().isF32(); + }); + + // Initialize the set of rewrite patterns.
+ OwningRewritePatternList patterns; + patterns.insert(converter, &getContext()); + mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), + converter); + + if (failed(applyPartialConversion(getOperation(), target, patterns))) + signalPassFailure(); + } +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// PassRegistration +//===----------------------------------------------------------------------===// + namespace mlir { void registerPatternsTestPass() { PassRegistration("test-return-type", @@ -766,5 +863,9 @@ PassRegistration( "test-legalize-unknown-root-patterns", "Test public remapped value mechanism in ConversionPatternRewriter"); + + PassRegistration( + "test-legalize-type-conversion", + "Test various type conversion functionalities in DialectConversion"); } } // namespace mlir