Move the buffer-assignment conversion utilities into Bufferize.cpp.

The out-of-line definitions of the BufferAssignmentTypeConverter members,
BufferAssignmentFuncOpConverter::matchAndRewrite and
BufferAssignmentCallOpConverter::matchAndRewrite (together with the
file-local CallOpResultMapping helper) are deleted from
mlir/lib/Transforms/BufferPlacement.cpp and added to the new file
mlir/lib/Transforms/Bufferize.cpp, which is listed in the MLIRTransforms
library in mlir/lib/Transforms/CMakeLists.txt.

The mlir/Analysis/Liveness.h and mlir/IR/Dominance.h includes move from
mlir/include/mlir/Transforms/Bufferize.h to BufferPlacement.cpp. That file
no longer includes mlir/Transforms/Bufferize.h and instead includes
mlir/Dialect/StandardOps/IR/Ops.h and
mlir/Interfaces/ControlFlowInterfaces.h directly.

NOTE(review): in this copy of the patch the line breaks appear collapsed and
template argument lists appear stripped (e.g. `SmallVectorImpl &results`,
`dyn_cast()`), so it will presumably not apply as-is; regenerate it from the
original commit before applying.

diff --git a/mlir/include/mlir/Transforms/Bufferize.h b/mlir/include/mlir/Transforms/Bufferize.h --- a/mlir/include/mlir/Transforms/Bufferize.h +++ b/mlir/include/mlir/Transforms/Bufferize.h @@ -18,10 +18,8 @@ #ifndef MLIR_TRANSFORMS_BUFFERIZE_H #define MLIR_TRANSFORMS_BUFFERIZE_H -#include "mlir/Analysis/Liveness.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/Dominance.h" #include "mlir/IR/Function.h" #include "mlir/IR/Operation.h" #include "mlir/Transforms/DialectConversion.h" diff --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp --- a/mlir/lib/Transforms/BufferPlacement.cpp +++ b/mlir/lib/Transforms/BufferPlacement.cpp @@ -56,10 +56,13 @@ //===----------------------------------------------------------------------===// #include "PassDetail.h" +#include "mlir/Analysis/Liveness.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Dominance.h" #include "mlir/IR/Operation.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Pass/Pass.h" -#include "mlir/Transforms/Bufferize.h" #include "mlir/Transforms/Passes.h" #include "llvm/ADT/SetOperations.h" @@ -809,245 +812,6 @@ } // end anonymous namespace -//===----------------------------------------------------------------------===// -// BufferAssignmentTypeConverter -//===----------------------------------------------------------------------===// - -/// Registers conversions into BufferAssignmentTypeConverter -BufferAssignmentTypeConverter::BufferAssignmentTypeConverter() { -  // Keep all types unchanged. -  addConversion([](Type type) { return type; }); -  // Convert RankedTensorType to MemRefType. -  addConversion([](RankedTensorType type) { -    return (Type)MemRefType::get(type.getShape(), type.getElementType()); -  }); -  // Convert UnrankedTensorType to UnrankedMemRefType. 
-  addConversion([](UnrankedTensorType type) { -    return (Type)UnrankedMemRefType::get(type.getElementType(), 0); -  }); -} - -/// This method tries to decompose a value of a certain type using provided -/// decompose callback functions. If it is unable to do so, the original value -/// is returned. -void BufferAssignmentTypeConverter::tryDecomposeValue( -    OpBuilder &builder, Location loc, Type type, Value value, -    SmallVectorImpl &results) { -  for (auto conversion : decomposeValueConversions) -    if (conversion(builder, loc, type, value, results) != llvm::None) -      return; -  results.push_back(value); -} - -/// This method tries to decompose a type using provided decompose callback -/// functions. If it is unable to do so, the original type is returned. -void BufferAssignmentTypeConverter::tryDecomposeType( -    Type type, SmallVectorImpl &types) { -  for (auto conversion : decomposeTypeConversions) -    if (conversion(type, types) != llvm::None) -      return; -  types.push_back(type); -} - -/// This method returns ResultConversionKind for the input type. -BufferAssignmentTypeConverter::ResultConversionKind -BufferAssignmentTypeConverter::getResultConversionKind(Type origin, -                                                       Type converted) { -  for (auto conversion : resultTypeConversions) { -    auto res = conversion(origin, converted); -    if (res != llvm::None) -      return res.getValue(); -  } -  return KeepAsFunctionResult; -} - -//===----------------------------------------------------------------------===// -// BufferAssignmentFuncOpConverter -//===----------------------------------------------------------------------===// - -/// Performs the actual function signature rewriting step. -LogicalResult BufferAssignmentFuncOpConverter::matchAndRewrite( -    mlir::FuncOp funcOp, ArrayRef operands, -    ConversionPatternRewriter &rewriter) const { -  auto funcType = funcOp.getType(); - -  // Convert function arguments using the provided TypeConverter. 
-  TypeConverter::SignatureConversion conversion(funcType.getNumInputs()); -  for (auto argType : llvm::enumerate(funcType.getInputs())) { -    SmallVector decomposedTypes, convertedTypes; -    converter.tryDecomposeType(argType.value(), decomposedTypes); -    converter.convertTypes(decomposedTypes, convertedTypes); -    conversion.addInputs(argType.index(), convertedTypes); -  } - -  // Convert the result types of the function. -  SmallVector newResultTypes; -  newResultTypes.reserve(funcOp.getNumResults()); -  for (Type resultType : funcType.getResults()) { -    SmallVector originTypes; -    converter.tryDecomposeType(resultType, originTypes); -    for (auto origin : originTypes) { -      Type converted = converter.convertType(origin); -      auto kind = converter.getResultConversionKind(origin, converted); -      if (kind == BufferAssignmentTypeConverter::AppendToArgumentsList) -        conversion.addInputs(converted); -      else -        // kind = BufferAssignmentTypeConverter::KeepAsFunctionResult -        newResultTypes.push_back(converted); -    } -  } - -  if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), converter, -                                         &conversion))) -    return failure(); - -  // Update the signature of the function. -  rewriter.updateRootInPlace(funcOp, [&] { -    funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(), -                                            newResultTypes)); -  }); -  return success(); -} - -//===----------------------------------------------------------------------===// -// BufferAssignmentCallOpConverter -//===----------------------------------------------------------------------===// - -namespace { -// This class represents a mapping from a result to a list of values and some -// results that have not yet constructed. Instead, the indices of these -// results in the operation that will be constructed are known. They will be -// replaced with the actual values when they are available. The order of -// adding to this mapping is important. 
-class CallOpResultMapping { -public: -  CallOpResultMapping() { order = 0; }; - -  /// Add an available value to the mapping. -  void addMapping(Value value) { toValuesMapping.push_back({order++, value}); } - -  /// Add the index of unavailble result value to the mapping. -  void addMapping(unsigned index) { -    toIndicesMapping.push_back({order++, index}); -  } - -  /// This method returns the mapping values list. The unknown result values -  /// that only their indicies are available are replaced with their values. -  void getMappingValues(ValueRange valuesToReplaceIndices, -                        SmallVectorImpl &values) { -    // Append available values to the list. -    SmallVector, 2> res(toValuesMapping.begin(), -                                                   toValuesMapping.end()); -    // Replace the indices with the actual values. -    llvm::for_each( -        toIndicesMapping, [&](const std::pair &entry) { -          assert(entry.second < valuesToReplaceIndices.size() && -                 "The value index is out of range."); -          res.push_back({entry.first, valuesToReplaceIndices[entry.second]}); -        }); -    // Sort the values based on their adding orders. -    llvm::sort(res, [](const std::pair &v1, -                       const std::pair &v2) { -      return v1.first < v2.first; -    }); -    // Fill the values. -    llvm::for_each(res, [&](const std::pair &entry) { -      values.push_back(entry.second); -    }); -  } - -private: -  /// Keeping the inserting order of mapping values. -  int order; - -  /// Containing the mapping values with their inserting orders. -  SmallVector, 2> toValuesMapping; - -  /// Containing the indices of result values with their inserting orders. -  SmallVector, 2> toIndicesMapping; -}; -} // namespace - -/// Performs the actual rewriting step. -LogicalResult BufferAssignmentCallOpConverter::matchAndRewrite( -    CallOp callOp, ArrayRef operands, -    ConversionPatternRewriter &rewriter) const { - -  Location loc = callOp.getLoc(); -  OpBuilder builder(callOp); -  SmallVector newOperands; - -  // TODO: if the CallOp references a FuncOp that only has a declaration (e.g. 
-  // to an externally defined symbol like an external library calls), only -  // convert if some special attribute is set. -  // This will allow more control of interop across ABI boundaries. - -  // Create the operands list of the new `CallOp`. It unpacks the decomposable -  // values if a decompose callback function has been provided by the user. -  for (auto operand : operands) { -    SmallVector values; -    this->converter.tryDecomposeValue(builder, loc, operand.getType(), operand, -                                      values); -    newOperands.append(values.begin(), values.end()); -  } - -  // Create the new result types for the new `CallOp` and a mapping from the old -  // result to new value(s). -  SmallVector newResultTypes; -  SmallVector mappings; -  mappings.resize(callOp.getNumResults()); -  for (auto result : llvm::enumerate(callOp.getResults())) { -    SmallVector originTypes; -    converter.tryDecomposeType(result.value().getType(), originTypes); -    auto &resultMapping = mappings[result.index()]; -    for (Type origin : originTypes) { -      Type converted = converter.convertType(origin); -      auto kind = converter.getResultConversionKind(origin, converted); -      if (kind == BufferAssignmentTypeConverter::KeepAsFunctionResult) { -        newResultTypes.push_back(converted); -        // The result value is not yet available. Its index is kept and it is -        // replaced with the actual value of the new `CallOp` later. -        resultMapping.addMapping(newResultTypes.size() - 1); -      } else { -        // kind = BufferAssignmentTypeConverter::AppendToArgumentsList -        MemRefType memref = converted.dyn_cast(); -        if (!memref) -          return callOp.emitError("Cannot allocate for a non-Memref type"); -        Value alloc = rewriter.create(loc, memref); -        newOperands.push_back(alloc); -        resultMapping.addMapping(alloc); -      } -    } -  } - -  CallOp newCallOp = rewriter.create(loc, callOp.getCallee(), -                                             newResultTypes, newOperands); - -  // Build a replacing value for each result to replace its uses. If a result -  // has multiple mapping values, it needs to be packed to a single value. 
-  OpBuilder nextBuilder(callOp.getOperation()->getNextNode()); -  SmallVector replacedValues; -  replacedValues.reserve(callOp.getNumResults()); -  for (unsigned i = 0, e = callOp.getNumResults(); i < e; ++i) { -    SmallVector valuesToPack; -    mappings[i].getMappingValues(newCallOp.getResults(), valuesToPack); -    if (valuesToPack.empty()) { -      // No replacement is required. -      replacedValues.push_back(nullptr); -    } else if (valuesToPack.size() == 1) { -      replacedValues.push_back(valuesToPack.front()); -    } else { -      // Values need to be packed using callback function. The same callback -      // that is used for materializeArgumentConversion is used for packing. -      Value packed = converter.materializeArgumentConversion( -          nextBuilder, loc, callOp.getType(i), valuesToPack); -      replacedValues.push_back(packed); -    } -  } -  rewriter.replaceOp(callOp, replacedValues); -  return success(); -} - //===----------------------------------------------------------------------===// // BufferPlacementPass construction //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/Bufferize.cpp b/mlir/lib/Transforms/Bufferize.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Transforms/Bufferize.cpp @@ -0,0 +1,251 @@ +//===- Bufferize.cpp - Bufferization utilities ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Bufferize.h"
+#include "mlir/IR/Operation.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// BufferAssignmentTypeConverter
+//===----------------------------------------------------------------------===//
+
+/// Registers conversions into BufferAssignmentTypeConverter
+BufferAssignmentTypeConverter::BufferAssignmentTypeConverter() {
+  // Keep all types unchanged.
+  addConversion([](Type type) { return type; });
+  // Convert RankedTensorType to MemRefType.
+  addConversion([](RankedTensorType type) {
+    return (Type)MemRefType::get(type.getShape(), type.getElementType());
+  });
+  // Convert UnrankedTensorType to UnrankedMemRefType.
+  addConversion([](UnrankedTensorType type) {
+    return (Type)UnrankedMemRefType::get(type.getElementType(), 0);
+  });
+}
+
+/// This method tries to decompose a value of a certain type using provided
+/// decompose callback functions. If it is unable to do so, the original value
+/// is returned.
+void BufferAssignmentTypeConverter::tryDecomposeValue(
+    OpBuilder &builder, Location loc, Type type, Value value,
+    SmallVectorImpl<Value> &results) {
+  for (const auto &conversion : decomposeValueConversions)
+    if (conversion(builder, loc, type, value, results) != llvm::None)
+      return;
+  results.push_back(value);
+}
+
+/// This method tries to decompose a type using provided decompose callback
+/// functions. If it is unable to do so, the original type is returned.
+void BufferAssignmentTypeConverter::tryDecomposeType(
+    Type type, SmallVectorImpl<Type> &types) {
+  for (const auto &conversion : decomposeTypeConversions)
+    if (conversion(type, types) != llvm::None)
+      return;
+  types.push_back(type);
+}
+
+/// This method returns ResultConversionKind for the input type.
+BufferAssignmentTypeConverter::ResultConversionKind
+BufferAssignmentTypeConverter::getResultConversionKind(Type origin,
+                                                       Type converted) {
+  for (const auto &conversion : resultTypeConversions) {
+    auto res = conversion(origin, converted);
+    if (res != llvm::None)
+      return res.getValue();
+  }
+  return KeepAsFunctionResult;
+}
+
+//===----------------------------------------------------------------------===//
+// BufferAssignmentFuncOpConverter
+//===----------------------------------------------------------------------===//
+
+/// Performs the actual function signature rewriting step.
+LogicalResult BufferAssignmentFuncOpConverter::matchAndRewrite(
+    mlir::FuncOp funcOp, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  auto funcType = funcOp.getType();
+
+  // Convert function arguments, failing if any type cannot be converted.
+  TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
+  for (auto argType : llvm::enumerate(funcType.getInputs())) {
+    SmallVector<Type, 2> decomposedTypes, convertedTypes;
+    converter.tryDecomposeType(argType.value(), decomposedTypes);
+    if (failed(converter.convertTypes(decomposedTypes, convertedTypes)))
+      return failure();
+    conversion.addInputs(argType.index(), convertedTypes);
+  }
+
+  // Convert the result types of the function.
+  SmallVector<Type, 2> newResultTypes;
+  newResultTypes.reserve(funcOp.getNumResults());
+  for (Type resultType : funcType.getResults()) {
+    SmallVector<Type, 2> originTypes;
+    converter.tryDecomposeType(resultType, originTypes);
+    for (Type origin : originTypes) {
+      Type converted = converter.convertType(origin);
+      if (!converted)
+        return failure();
+      auto kind = converter.getResultConversionKind(origin, converted);
+      if (kind == BufferAssignmentTypeConverter::AppendToArgumentsList)
+        conversion.addInputs(converted);
+      else
+        // kind = BufferAssignmentTypeConverter::KeepAsFunctionResult
+        newResultTypes.push_back(converted);
+    }
+  }
+
+  if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), converter,
+                                         &conversion)))
+    return failure();
+
+  // Update the signature of the function.
+  rewriter.updateRootInPlace(funcOp, [&] {
+    funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
+                                            newResultTypes));
+  });
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// BufferAssignmentCallOpConverter
+//===----------------------------------------------------------------------===//
+
+namespace {
+// This class represents a mapping from a result to a list of values and some
+// results that have not yet been constructed. Instead, the indices of these
+// results in the operation that will be constructed are known. They will be
+// replaced with the actual values when they are available. The order of
+// adding to this mapping is important.
+class CallOpResultMapping {
+public:
+  /// Add an available value to the mapping.
+  void addMapping(Value value) { toValuesMapping.push_back({order++, value}); }
+
+  /// Add the index of an unavailable result value to the mapping.
+  void addMapping(unsigned index) {
+    toIndicesMapping.push_back({order++, index});
+  }
+
+  /// This method returns the mapping values list. Result values that are only
+  /// known by their indices are looked up in `valuesToReplaceIndices`.
+  void getMappingValues(ValueRange valuesToReplaceIndices,
+                        SmallVectorImpl<Value> &values) {
+    // Append available values to the list.
+    SmallVector<std::pair<unsigned, Value>, 2> res(toValuesMapping.begin(),
+                                                   toValuesMapping.end());
+    // Replace the indices with the actual values.
+    llvm::for_each(
+        toIndicesMapping, [&](const std::pair<unsigned, unsigned> &entry) {
+          assert(entry.second < valuesToReplaceIndices.size() &&
+                 "The value index is out of range.");
+          res.push_back({entry.first, valuesToReplaceIndices[entry.second]});
+        });
+    // Sort the values based on their adding orders.
+    llvm::sort(res, [](const std::pair<unsigned, Value> &v1,
+                       const std::pair<unsigned, Value> &v2) {
+      return v1.first < v2.first;
+    });
+    // Fill the values.
+    llvm::for_each(res, [&](const std::pair<unsigned, Value> &entry) {
+      values.push_back(entry.second);
+    });
+  }
+
+private:
+  /// Counter recording the insertion order of mapping values.
+  unsigned order = 0;
+
+  /// Containing the mapping values with their inserting orders.
+  SmallVector<std::pair<unsigned, Value>, 2> toValuesMapping;
+
+  /// Containing the indices of result values with their inserting orders.
+  SmallVector<std::pair<unsigned, unsigned>, 2> toIndicesMapping;
+};
+} // namespace
+
+/// Performs the actual rewriting step.
+LogicalResult BufferAssignmentCallOpConverter::matchAndRewrite(
+    CallOp callOp, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+
+  Location loc = callOp.getLoc();
+  SmallVector<Value, 2> newOperands;
+
+  // TODO: if the CallOp references a FuncOp that only has a declaration (e.g.
+  // to an externally defined symbol like an external library calls), only
+  // convert if some special attribute is set.
+  // This will allow more control of interop across ABI boundaries.
+
+  // Create the operands of the new `CallOp`, unpacking decomposable values.
+  // Ops are built with `rewriter` so the conversion driver can undo them.
+  for (Value operand : operands) {
+    SmallVector<Value, 2> values;
+    converter.tryDecomposeValue(rewriter, loc, operand.getType(), operand,
+                                values);
+    newOperands.append(values.begin(), values.end());
+  }
+
+  // Create the new result types for the new `CallOp` and a mapping from the old
+  // result to new value(s).
+  SmallVector<Type, 2> newResultTypes;
+  SmallVector<CallOpResultMapping, 4> mappings(callOp.getNumResults());
+  for (auto result : llvm::enumerate(callOp.getResults())) {
+    SmallVector<Type, 2> originTypes;
+    converter.tryDecomposeType(result.value().getType(), originTypes);
+    auto &resultMapping = mappings[result.index()];
+    for (Type origin : originTypes) {
+      Type converted = converter.convertType(origin);
+      if (!converted)
+        return failure();
+      auto kind = converter.getResultConversionKind(origin, converted);
+      if (kind == BufferAssignmentTypeConverter::KeepAsFunctionResult) {
+        newResultTypes.push_back(converted);
+        // The result value is not yet available. Its index is kept and it is
+        // replaced with the actual value of the new `CallOp` later.
+        resultMapping.addMapping(newResultTypes.size() - 1);
+      } else {
+        // kind = BufferAssignmentTypeConverter::AppendToArgumentsList
+        MemRefType memref = converted.dyn_cast<MemRefType>();
+        if (!memref)
+          return callOp.emitError("Cannot allocate for a non-Memref type");
+        Value alloc = rewriter.create<AllocOp>(loc, memref);
+        newOperands.push_back(alloc);
+        resultMapping.addMapping(alloc);
+      }
+    }
+  }
+
+  CallOp newCallOp = rewriter.create<CallOp>(loc, callOp.getCallee(),
+                                             newResultTypes, newOperands);
+
+  // Build a replacing value for each result to replace its uses. Results with
+  // multiple mapping values are packed right after `newCallOp` via `rewriter`.
+  SmallVector<Value, 2> replacedValues;
+  replacedValues.reserve(callOp.getNumResults());
+  for (unsigned i = 0, e = callOp.getNumResults(); i < e; ++i) {
+    SmallVector<Value, 2> valuesToPack;
+    mappings[i].getMappingValues(newCallOp.getResults(), valuesToPack);
+    if (valuesToPack.empty()) {
+      // No replacement is required.
+      replacedValues.push_back(nullptr);
+    } else if (valuesToPack.size() == 1) {
+      replacedValues.push_back(valuesToPack.front());
+    } else {
+      // Values need to be packed using callback function. The same callback
+      // that is used for materializeArgumentConversion is used for packing.
+      Value packed = converter.materializeArgumentConversion(
+          rewriter, loc, callOp.getType(i), valuesToPack);
+      replacedValues.push_back(packed);
+    }
+  }
+  rewriter.replaceOp(callOp, replacedValues);
+  return success();
+}
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@
 
 add_mlir_library(MLIRTransforms
   BufferPlacement.cpp
+  Bufferize.cpp
   Canonicalizer.cpp
   CopyRemoval.cpp
   CSE.cpp