diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -72,16 +72,16 @@ /// used if the new types are not intended to remap an existing input. void addInputs(ArrayRef types); - /// Remap an input of the original signature with a range of types in the - /// new signature. - void remapInput(unsigned origInputNo, unsigned newInputNo, - unsigned newInputCount = 1); - /// Remap an input of the original signature to another `replacement` /// value. This drops the original argument. void remapInput(unsigned origInputNo, Value replacement); private: + /// Remap an input of the original signature with a range of types in the + /// new signature. + void remapInput(unsigned origInputNo, unsigned newInputNo, + unsigned newInputCount = 1); + /// The remapping information for each of the original arguments. SmallVector, 4> remappedInputs; @@ -149,16 +149,29 @@ /// Return true if the given type is legal for this type converter, i.e. the /// type converts to itself. bool isLegal(Type type); + /// Return true if all of the given types are legal for this type converter. + template + std::enable_if_t::value && + !std::is_convertible::value, + bool> + isLegal(RangeT &&range) { + return llvm::all_of(range, [this](Type type) { return isLegal(type); }); + } + /// Return true if the given operation has legal operand and result types. + bool isLegal(Operation *op); /// Return true if the inputs and outputs of the given function type are /// legal. - bool isSignatureLegal(FunctionType funcType); + bool isSignatureLegal(FunctionType ty); /// This method allows for converting a specific argument of a signature. It /// takes as inputs the original argument input number, type. /// On success, it populates 'result' with any new mappings. LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result); + LogicalResult convertSignatureArgs(TypeRange types, + SignatureConversion &result, + unsigned origInputOffset = 0); /// This function converts the type signature of the given block, by invoking /// 'convertSignatureArg' for each argument. This function should return a @@ -214,6 +227,8 @@ /// Register a type conversion. void registerConversion(ConversionCallbackFn callback) { conversions.emplace_back(std::move(callback)); + cachedDirectConversions.clear(); + cachedMultiConversions.clear(); } /// Generate a wrapper for the given materialization callback. The callback @@ -240,6 +255,13 @@ /// The list of registered materialization functions. SmallVector materializations; + + /// A set of cached conversions to avoid recomputing in the common case. + /// Direct 1-1 conversions are the most common, so this cache stores the + /// successful 1-1 conversions as well as all failed conversions. + DenseMap cachedDirectConversions; + /// This cache stores the successful 1->N conversions, where N != 1. + DenseMap> cachedMultiConversions; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp @@ -120,10 +120,8 @@ target.addLegalDialect(); // Mark all Linalg operations illegal as long as they work on tensors. - auto isIllegalType = [&](Type type) { return !converter.isLegal(type); }; auto isLegalOperation = [&](Operation *op) { - return llvm::none_of(op->getOperandTypes(), isIllegalType) && - llvm::none_of(op->getResultTypes(), isIllegalType); + return converter.isLegal(op); }; target.addDynamicallyLegalDialect( Optional( @@ -131,7 +129,7 @@ // Mark Standard Return operations illegal as long as one operand is tensor. target.addDynamicallyLegalOp([&](mlir::ReturnOp returnOp) { - return llvm::none_of(returnOp.getOperandTypes(), isIllegalType); + return converter.isLegal(returnOp.getOperandTypes()); }); // Mark the function operation illegal as long as an argument is tensor. diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -1742,11 +1742,35 @@ /// This hooks allows for converting a type. LogicalResult TypeConverter::convertType(Type t, SmallVectorImpl &results) { + auto existingIt = cachedDirectConversions.find(t); + if (existingIt != cachedDirectConversions.end()) { + if (existingIt->second) + results.push_back(existingIt->second); + return success(existingIt->second != nullptr); + } + auto multiIt = cachedMultiConversions.find(t); + if (multiIt != cachedMultiConversions.end()) { + results.append(multiIt->second.begin(), multiIt->second.end()); + return success(); + } + // Walk the added converters in reverse order to apply the most recently // registered first. - for (ConversionCallbackFn &converter : llvm::reverse(conversions)) - if (Optional result = converter(t, results)) - return *result; + size_t currentCount = results.size(); + for (ConversionCallbackFn &converter : llvm::reverse(conversions)) { + if (Optional result = converter(t, results)) { + if (!succeeded(*result)) { + cachedDirectConversions.try_emplace(t, nullptr); + return failure(); + } + auto newTypes = ArrayRef(results).drop_front(currentCount); + if (newTypes.size() == 1) + cachedDirectConversions.try_emplace(t, newTypes.front()); + else + cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes)); + return success(); + } + } return failure(); } @@ -1775,18 +1799,16 @@ /// Return true if the given type is legal for this type converter, i.e. the /// type converts to itself. -bool TypeConverter::isLegal(Type type) { - SmallVector results; - return succeeded(convertType(type, results)) && results.size() == 1 && - results.front() == type; +bool TypeConverter::isLegal(Type type) { return convertType(type) == type; } +/// Return true if the given operation has legal operand and result types. +bool TypeConverter::isLegal(Operation *op) { + return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes()); } /// Return true if the inputs and outputs of the given function type are /// legal. -bool TypeConverter::isSignatureLegal(FunctionType funcType) { - return llvm::all_of( - llvm::concat(funcType.getInputs(), funcType.getResults()), - [this](Type type) { return isLegal(type); }); +bool TypeConverter::isSignatureLegal(FunctionType ty) { + return isLegal(llvm::concat(ty.getInputs(), ty.getResults())); } /// This hook allows for converting a specific argument of a signature. @@ -1805,6 +1827,14 @@ result.addInputs(inputNo, convertedTypes); return success(); } +LogicalResult TypeConverter::convertSignatureArgs(TypeRange types, + SignatureConversion &result, + unsigned origInputOffset) { + for (unsigned i = 0, e = types.size(); i != e; ++i) + if (failed(convertSignatureArg(origInputOffset + i, types[i], result))) + return failure(); + return success(); +} Value TypeConverter::materializeConversion(PatternRewriter &rewriter, Location loc, Type resultType, @@ -1815,6 +1845,17 @@ return nullptr; } +/// This function converts the type signature of the given block, by invoking +/// 'convertSignatureArg' for each argument. This function should return a valid +/// conversion for the signature on success, None otherwise. +auto TypeConverter::convertBlockSignature(Block *block) + -> Optional { + SignatureConversion conversion(block->getNumArguments()); + if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion))) + return llvm::None; + return conversion; +} + /// Create a default conversion pattern that rewrites the type signature of a /// FuncOp. namespace { @@ -1828,15 +1869,11 @@ ConversionPatternRewriter &rewriter) const override { FunctionType type = funcOp.getType(); - // Convert the original function arguments. + // Convert the original function types. TypeConverter::SignatureConversion result(type.getNumInputs()); - for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i) - if (failed(converter.convertSignatureArg(i, type.getInput(i), result))) - return failure(); - - // Convert the original function results. SmallVector convertedResults; - if (failed(converter.convertTypes(type.getResults(), convertedResults))) + if (failed(converter.convertSignatureArgs(type.getInputs(), result)) || + failed(converter.convertTypes(type.getResults(), convertedResults))) return failure(); // Update the function signature in-place. @@ -1859,19 +1896,6 @@ patterns.insert(ctx, converter); } -/// This function converts the type signature of the given block, by invoking -/// 'convertSignatureArg' for each argument. This function should return a valid -/// conversion for the signature on success, None otherwise. -auto TypeConverter::convertBlockSignature(Block *block) - -> Optional { - SignatureConversion conversion(block->getNumArguments()); - for (unsigned i = 0, e = block->getNumArguments(); i != e; ++i) - if (failed(convertSignatureArg(i, block->getArgument(i).getType(), - conversion))) - return llvm::None; - return conversion; -} - //===----------------------------------------------------------------------===// // ConversionTarget //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -314,10 +314,9 @@ // Convert the original entry arguments. TypeConverter::SignatureConversion result(entry->getNumArguments()); - for (unsigned i = 0, e = entry->getNumArguments(); i != e; ++i) - if (failed(converter.convertSignatureArg( - i, entry->getArgument(i).getType(), result))) - return failure(); + if (failed( + converter.convertSignatureArgs(entry->getArgumentTypes(), result))) + return failure(); // Convert the region signature and just drop the operation. rewriter.applySignatureConversion(®ion, result); diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp --- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp +++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp @@ -124,10 +124,8 @@ target.addLegalDialect(); // Mark all Linalg operations illegal as long as they work on tensors. - auto isIllegalType = [&](Type type) { return !converter.isLegal(type); }; auto isLegalOperation = [&](Operation *op) { - return llvm::none_of(op->getOperandTypes(), isIllegalType) && - llvm::none_of(op->getResultTypes(), isIllegalType); + return converter.isLegal(op); }; target.addDynamicallyLegalDialect( Optional( @@ -135,14 +133,12 @@ // Mark Standard Return operations illegal as long as one operand is tensor. target.addDynamicallyLegalOp([&](mlir::ReturnOp returnOp) { - return llvm::none_of(returnOp.getOperandTypes(), isIllegalType); + return converter.isLegal(returnOp.getOperandTypes()); }); // Mark Standard Call Operation illegal as long as it operates on tensor. - target.addDynamicallyLegalOp([&](mlir::CallOp callOp) { - return llvm::none_of(callOp.getOperandTypes(), isIllegalType) && - llvm::none_of(callOp.getResultTypes(), isIllegalType); - }); + target.addDynamicallyLegalOp( + [&](mlir::CallOp callOp) { return converter.isLegal(callOp); }); // Mark the function whose arguments are in tensor-type illegal. target.addDynamicallyLegalOp([&](FuncOp funcOp) {