diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -122,11 +122,6 @@ /// pointers to memref descriptors for arguments. LLVM::LLVMType convertFunctionTypeCWrapper(FunctionType type); - /// Creates descriptor structs from individual values constituting them. - Operation *materializeConversion(PatternRewriter &rewriter, Type type, - ArrayRef values, - Location loc) override; - /// Gets the LLVM representation of the index type. The returned type is an /// integer type with the size configured for this type converter. LLVM::LLVMType getIndexType(); diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -113,6 +113,13 @@ registerConversion(wrapCallback(std::forward(callback))); } + using MaterializationCallbackFn = std::function( PatternRewriter &, Type, ValueRange, Location)>; + /// Register a materialization callback; later registrations are tried first. + void addMaterialization(MaterializationCallbackFn callback) { + materializations.emplace_back(std::move(callback)); + } + /// Convert the given type. This function should return failure if no valid /// conversion exists, success otherwise. If the new set of types is empty, /// the type is removed and any usages of the existing value are expected to @@ -154,12 +161,9 @@ /// 'inputs' as operands. This hook must be overridden when a type conversion /// results in more than one type, or if a type conversion may persist after /// the conversion has finished.
- virtual Operation *materializeConversion(PatternRewriter &rewriter, - Type resultType, - ArrayRef inputs, - Location loc) { - llvm_unreachable("expected 'materializeConversion' to be overridden"); - } + /// Tries registered materializations (newest first); null if none apply. + Value materializeConversion(PatternRewriter &rewriter, Location loc, + Type resultType, ValueRange inputs); private: /// The signature of the callback used to convert a type. If the new set of @@ -206,6 +210,8 @@ /// The set of registered conversion functions. SmallVector conversions; + /// The set of registered materialization functions. + SmallVector materializations; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -150,6 +150,20 @@ // LLVMType is legal, so add a pass-through conversion. addConversion([](LLVM::LLVMType type) { return type; }); + + // Materialization for memrefs creates descriptor structs from individual + // values constituting them. + addMaterialization([&](PatternRewriter &rewriter, Type resultType, + ValueRange inputs, Location loc) -> Optional { + if (inputs.size() == 1) + return llvm::None; + if (auto unrankedMemRefType = resultType.dyn_cast()) + return UnrankedMemRefDescriptor::pack(rewriter, loc, *this, + unrankedMemRefType, inputs); + if (auto memRefType = resultType.dyn_cast()) + return MemRefDescriptor::pack(rewriter, loc, *this, memRefType, inputs); + return llvm::None; + }); } /// Returns the MLIR context. @@ -297,22 +311,6 @@ return LLVM::LLVMType::getFunctionTy(resultType, inputs, false); } -/// Creates descriptor structs from individual values constituting them.
-Operation *LLVMTypeConverter::materializeConversion(PatternRewriter &rewriter, - Type type, - ArrayRef values, - Location loc) { - if (auto unrankedMemRefType = type.dyn_cast()) - return UnrankedMemRefDescriptor::pack(rewriter, loc, *this, - unrankedMemRefType, values) - .getDefiningOp(); - - auto memRefType = type.dyn_cast(); - assert(memRefType && "1->N conversion is only supported for memrefs"); - return MemRefDescriptor::pack(rewriter, loc, *this, memRefType, values) - .getDefiningOp(); -} - // Convert a MemRef to an LLVM type. The result is a MemRef descriptor which // contains: // 1. the pointer to the data buffer, followed by diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -305,27 +305,20 @@ // persist in the IR after conversion. if (!origArg.use_empty()) { rewriter.setInsertionPointToStart(newBlock); - auto *newOp = typeConverter->materializeConversion( - rewriter, origArg.getType(), llvm::None, loc); - origArg.replaceAllUsesWith(newOp->getResult(0)); + Value newArg = typeConverter->materializeConversion( + rewriter, loc, origArg.getType(), llvm::None); + assert(newArg && + "couldn't materialize a block argument after 1->0 conversion"); + origArg.replaceAllUsesWith(newArg); } continue; } - // If mapping is 1-1, replace the remaining uses and drop the cast - // operation. - // FIXME(riverriddle) This should check that the result type and operand - // type are the same, otherwise it should force a conversion to be - // materialized. - if (argInfo->newArgSize == 1) { - origArg.replaceAllUsesWith( - mapping.lookupOrDefault(newBlock->getArgument(argInfo->newArgIdx))); - continue; - } - - // Otherwise this is a 1->N value mapping. + // Otherwise this is a 1->1+ value mapping.
Value castValue = argInfo->castValue; - assert(argInfo->newArgSize > 1 && castValue && "expected 1->N mapping"); + assert(argInfo->newArgSize >= 1 && "expected 1->1+ mapping"); + assert(castValue && + "expected a replacement value to be recorded for a 1->1+ mapping"); // If the argument is still used, replace it with the generated cast. if (!origArg.use_empty()) @@ -333,7 +326,7 @@ // If all users of the cast were removed, we can drop it. Otherwise, keep // the operation alive and let the user handle any remaining usages. - if (castValue.use_empty()) + if (castValue.use_empty() && castValue.getDefiningOp()) castValue.getDefiningOp()->erase(); } } @@ -389,22 +382,22 @@ continue; } - // If this is a 1->1 mapping, then map the argument directly. - if (inputMap->size == 1) { - mapping.map(origArg, newArgs[inputMap->inputNo]); - info.argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size); - continue; - } - - // Otherwise, this is a 1->N mapping. Call into the provided type converter - // to pack the new values. + // Otherwise, this is a 1->1+ mapping. Call into the provided type converter + // to pack the new values. For 1->1 mappings, if there is no materialization + // provided, use the argument directly instead. auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size); - Operation *cast = typeConverter->materializeConversion( - rewriter, origArg.getType(), replArgs, loc); - assert(cast->getNumResults() == 1); - mapping.map(origArg, cast->getResult(0)); + Value newArg = typeConverter + ?
typeConverter->materializeConversion( + rewriter, loc, origArg.getType(), replArgs) + : Value(); + if (!newArg) { + assert(replArgs.size() == 1 && + "couldn't materialize the result of 1->N conversion"); + newArg = replArgs.front(); + } + mapping.map(origArg, newArg); info.argInfo[i] = - ConvertedArgInfo(inputMap->inputNo, inputMap->size, cast->getResult(0)); + ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg); } // Remove the original block from the region and return the new one. @@ -1755,6 +1748,15 @@ return success(); } +Value TypeConverter::materializeConversion(PatternRewriter &rewriter, + Location loc, Type resultType, + ValueRange inputs) { + for (MaterializationCallbackFn &fn : llvm::reverse(materializations)) + if (Optional result = fn(rewriter, resultType, inputs, loc)) + return result.getValue(); + return nullptr; +} + /// Create a default conversion pattern that rewrites the type signature of a /// FuncOp. namespace { diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -448,7 +448,10 @@ namespace { struct TestTypeConverter : public TypeConverter { using TypeConverter::TypeConverter; - TestTypeConverter() { addConversion(convertType); } + TestTypeConverter() { + addConversion(convertType); + addMaterialization(materializeCast); + } static LogicalResult convertType(Type t, SmallVectorImpl &results) { // Drop I16 types. @@ -472,12 +475,14 @@ return success(); } - /// Override the hook to materialize a conversion. This is necessary because - /// we generate 1->N type mappings. - Operation *materializeConversion(PatternRewriter &rewriter, Type resultType, - ArrayRef inputs, - Location loc) override { - return rewriter.create(loc, resultType, inputs); + /// Materialization hook: a single input is forwarded unchanged; otherwise + /// an op producing 'resultType' from 'inputs' is created (1->0, 1->N cases).
+ static Optional materializeCast(PatternRewriter &rewriter, + Type resultType, ValueRange inputs, + Location loc) { + if (inputs.size() == 1) + return inputs[0]; + return rewriter.create(loc, resultType, inputs).getResult(); + } };