diff --git a/mlir/include/mlir/Transforms/BufferPlacement.h b/mlir/include/mlir/Transforms/BufferPlacement.h --- a/mlir/include/mlir/Transforms/BufferPlacement.h +++ b/mlir/include/mlir/Transforms/BufferPlacement.h @@ -52,6 +52,111 @@ Operation *operation; }; +/// A helper type converter class for using inside Buffer Assignment operation +/// conversion patterns. The default constructor keeps all the types intact +/// except for the ranked-tensor types which are converted to memref types. +class BufferAssignmentTypeConverter : public TypeConverter { +public: + /// This enum is for showing how buffer placement operation converters should + /// conduct with certain result type after type conversion. This value can be + /// set/get for each specific type using setResultConversionKind or + /// getResultConversionKind. + enum ResultConversionKind { AppendToArgumentsList, KeepAsFunctionResult }; + + BufferAssignmentTypeConverter(); + + /// This method tries to decompose a value of a certain type using provided + /// decompose callback functions. If it is unable to do so, the original value + /// is appended to the results unchanged. + void tryDecomposeValue(OpBuilder &, Location, Type, Value, + SmallVectorImpl &); + + /// This method tries to decompose a type using provided decompose callback + /// functions. If none applies, the original type is appended unchanged. + void tryDecomposeType(Type, SmallVectorImpl &); + + /// This method registers a callback function that will be called to decompose + /// a value of a certain type into several values. + template ::template arg_t<2>> + void addDecomposeValueConversion(FnT &&callback) { + decomposeValueConversions.emplace_back( + wrapDecomposeValueConversionCallback(std::forward(callback))); + } + + /// Registers a callback that decomposes a type into several types. It is + /// copied into the wrapper since it is also forwarded to addConversion.
+ template ::template arg_t<0>> + void addDecomposeTypeConversion(FnT &&callback) { + auto wrapper = + wrapDecomposeTypeConversionCallback<T>(callback); + decomposeTypeConversions.emplace_back(wrapper); + addConversion(std::forward(callback)); + } + + /// This method returns ResultConversionKind for the mapping from `origin` + /// type to `input` type. + ResultConversionKind getResultConversionKind(Type origin, Type input); + + /// This method registers ResultConversionKind for the mapping from type 'T' + /// to type 'U'; `kind` is captured by value as the callback outlives it. + template + void setResultConversionKind(ResultConversionKind kind) { + assert((kind != AppendToArgumentsList || + llvm::is_one_of::value) && + "Only the memref typed values can be set to be appended to the " + "function argument list at the moment"); + resultTypeConversions.emplace_back( + [=](Type origin, Type input) -> Optional { + if (origin.template isa() && input.template isa()) + return kind; + return llvm::None; + }); + } + +private: + using DecomposeValueConversionCallFn = std::function( + OpBuilder &, Location, Type, Value, SmallVectorImpl &)>; + + using DecomposeTypeConversionCallFn = + std::function(Type, SmallVectorImpl &)>; + + using ResultConversionKindFn = + std::function(Type, Type)>; + + /// Generate a wrapper for the given decompose value conversion callback. + template + DecomposeValueConversionCallFn + wrapDecomposeValueConversionCallback(FnT &&callback) { + return [callback = std::forward(callback)]( + OpBuilder &builder, Location loc, Type type, Value value, + SmallVectorImpl &newValues) -> Optional { + if (T derivedType = type.dyn_cast()) + return callback(builder, loc, derivedType, value, newValues); + return llvm::None; + }; + } + + /// Generate a wrapper for the given decompose type conversion callback.
+ template + DecomposeTypeConversionCallFn + wrapDecomposeTypeConversionCallback(FnT &&callback) { + return [callback = std::forward(callback)]( + Type type, + SmallVectorImpl &results) -> Optional { + T derivedType = type.dyn_cast(); + if (!derivedType) + return llvm::None; + return callback(derivedType, results); + }; + } + + SmallVector resultTypeConversions; + SmallVector decomposeValueConversions; + SmallVector decomposeTypeConversions; +}; + /// Helper conversion pattern that encapsulates a BufferAssignmentPlacer /// instance. Sample usage: /// class CustomConversionPattern : public @@ -68,43 +173,22 @@ public: explicit BufferAssignmentOpConversionPattern( MLIRContext *context, BufferAssignmentPlacer *bufferAssignment = nullptr, - TypeConverter *converter = nullptr, PatternBenefit benefit = 1) + BufferAssignmentTypeConverter *converter = nullptr, + PatternBenefit benefit = 1) : OpConversionPattern(context, benefit), - bufferAssignment(bufferAssignment), converter(converter) {} + bufferAssignment(bufferAssignment), converter(converter) { + assert(converter && "The type converter has not been defined"); + } protected: BufferAssignmentPlacer *bufferAssignment; - TypeConverter *converter; -}; - -/// A helper type converter class for using inside Buffer Assignment operation -/// conversion patterns. The default constructor keeps all the types intact -/// except for the ranked-tensor types which is converted to memref types. -class BufferAssignmentTypeConverter : public TypeConverter { -public: - BufferAssignmentTypeConverter(); - - /// A helper function to check if `type` has been converted from non-memref - /// type to memref. - static bool isConvertedMemref(Type type, Type before); + BufferAssignmentTypeConverter *converter; }; -namespace detail { - -/// Converts the signature of the function based on whether the function is -/// allowed to return memref typed results or not using -/// `allowMemrefFunctionResults` parameter.
If this option is false, then it -/// adds an extra function argument as an output buffer for each function result -/// which is going to be a memref type only after type conversion. The -/// other function result types remain unchanged. If -/// `allowMemrefFunctionResults` is true, the types are converted in place. -/// Any changes in function signature need to be applied -/// to return and caller operations. `BufferAssignmentReturnOpConverter` and -/// `BufferAssignmentCallOpConverter` are two helper function that match the -/// return and caller operation with the new function signature. Furthermore, -/// `BufferAssignmentTypeConverter` is a helper `TypeConverter` for converting -/// tensor typed values to memref typed ones. -template +/// Converts the signature of the function using BufferAssignmentTypeConverter. +/// Each result type of the function is kept as a function result or appended to +/// the function arguments list based on ResultConversionKind for the converted +/// result type. class BufferAssignmentFuncOpConverter : public BufferAssignmentOpConversionPattern { public: @@ -112,58 +196,16 @@ FuncOp>::BufferAssignmentOpConversionPattern; /// Performs the actual signature rewriting step. - LogicalResult - matchAndRewrite(mlir::FuncOp funcOp, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - if (!converter) - return funcOp.emitError("The type converter has not been defined for " - "BufferAssignmentFuncOpConverter"); - auto funcType = funcOp.getType(); - - // Convert function arguments using the provided TypeConverter.
- TypeConverter::SignatureConversion conversion(funcType.getNumInputs()); - for (auto argType : llvm::enumerate(funcType.getInputs())) - conversion.addInputs(argType.index(), - converter->convertType(argType.value())); - - // If allowMemrefFunctionResults is false and a function result type is not - // a memref but it would be a memref after type conversion, a new argument - // should be appended to the function arguments list for this result. - // Otherwise, it remains unchanged as a function result. - SmallVector newResultTypes; - newResultTypes.reserve(funcOp.getNumResults()); - for (Type resType : funcType.getResults()) { - Type convertedType = converter->convertType(resType); - if (!allowMemrefFunctionResults && - BufferAssignmentTypeConverter::isConvertedMemref(convertedType, - resType)) - conversion.addInputs(convertedType); - else - newResultTypes.push_back(convertedType); - } - if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *converter, - &conversion))) - return failure(); - - // Update the signature of the function. - rewriter.updateRootInPlace(funcOp, [&] { - funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(), - newResultTypes)); - }); - return success(); - } + LogicalResult matchAndRewrite(mlir::FuncOp, ArrayRef, + ConversionPatternRewriter &) const override; }; /// Rewrites the `ReturnOp` to conform with the changed function signature. -/// if allowMemrefFunctionResults is false, operands that correspond to return -/// values and have been rewritten from illegal typed results to memref -/// arguments are dropped. In their place, a corresponding copy operation from -/// the operand to the output function argument is inserted. Otherwise, the -/// memref typed operands are returned. -/// Note: If this pattern rewriter is used with BufferAssignmentFuncOpConverter, -/// allowMemrefFunctionResults must be set/unset for both.
+/// Operands that correspond to return values and their types have been set to +/// AppendToArgumentsList are dropped. In their place, a corresponding copy +/// operation from the operand to the target function argument is inserted. template + typename CopyOpTy> class BufferAssignmentReturnOpConverter : public BufferAssignmentOpConversionPattern { public: @@ -174,44 +216,49 @@ LogicalResult matchAndRewrite(ReturnOpSourceTy returnOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - // If the memref typed results can be returned as function results, the new - // `ReturnOp` should only return the type converted operands. - if (allowMemrefFunctionResults) { - rewriter.replaceOpWithNewOp(returnOp, operands); - return success(); + Location loc = returnOp.getLoc(); + + // Split the operands depending on whether they need a copy operation or + // they remain as operands of the return operation. If an operand is + // decomposable and a decompose callback function has been provided by the + // user, it will be unpacked. + SmallVector newOperands, needCopyOperands; + OpBuilder builder(returnOp); + for (auto operand : llvm::enumerate(operands)) { + SmallVector values; + this->converter->tryDecomposeValue( + builder, loc, operand.value().getType(), operand.value(), values); + Type type = returnOp.getOperand(operand.index()).getType(); + SmallVector originTypes; + this->converter->tryDecomposeType(type, originTypes); + for (auto value : llvm::enumerate(values)) { + Type origin = originTypes[value.index()]; + Type converted = value.value().getType(); + auto kind = this->converter->getResultConversionKind(origin, converted); + if (kind == BufferAssignmentTypeConverter::KeepAsFunctionResult) + newOperands.push_back(value.value()); + else + // kind = BufferAssignmentTypeConverter::AppendToArgumentsList + needCopyOperands.push_back(value.value()); + } } - // Split the operands by their kinds whether they are converted memref or - // not.
- SmallVector needCopyOperands, newOperands; - unsigned operandsSize = operands.size(); - needCopyOperands.reserve(operandsSize); - newOperands.reserve(operandsSize); - for (auto operand : llvm::enumerate(operands)) - if (BufferAssignmentTypeConverter::isConvertedMemref( - operand.value().getType(), - returnOp.getOperand(operand.index()).getType())) - needCopyOperands.push_back(operand.value()); - else - newOperands.push_back(operand.value()); - + // Insert Copy operations instead for the operands that have been removed + // from operand list and appended to the function arguments list. Block &entryBlock = returnOp.getParentRegion()->front(); unsigned numFuncArgs = entryBlock.getNumArguments(); - - // Find the index of the first destination buffer. - assert(needCopyOperands.size() <= numFuncArgs && - "The number of operands of return operation is more than the " - "number of function arguments."); + if (needCopyOperands.size() > numFuncArgs) + return returnOp.emitError( + "The number of operands that need Copy operations is more " + "than the number of target function arguments (the entry " + "block arguments of the enclosing region)."); unsigned destArgNum = numFuncArgs - needCopyOperands.size(); rewriter.setInsertionPoint(returnOp); for (Value operand : needCopyOperands) { - // Insert a `CopyOp` for each converted memref-type operand. - rewriter.create(returnOp.getLoc(), operand, + rewriter.create(loc, operand, entryBlock.getArgument(destArgNum)); ++destArgNum; } - - // Insert the new target Return operation. rewriter.replaceOpWithNewOp(returnOp, newOperands); return success(); } @@ -219,94 +266,32 @@ /// Rewrites the `CallOp` to match its operands and results with the signature /// of the callee after rewriting the callee with -/// BufferAssignmentFuncOpConverter. If allowMemrefFunctionResults is false, a -/// buffer is allocated as an output buffer only for each memref typed result -/// that has been rewritten.
The new allocated buffer is passed through the -/// operands list of the new `CallOp`. -/// Note: If this pattern rewriter is used with BufferAssignmentFuncOpConverter, -/// allowMemrefFunctionResults must be set/unset for both. -template +/// BufferAssignmentFuncOpConverter. class BufferAssignmentCallOpConverter : public BufferAssignmentOpConversionPattern { public: using BufferAssignmentOpConversionPattern< CallOp>::BufferAssignmentOpConversionPattern; - LogicalResult - matchAndRewrite(CallOp callOp, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - if (!converter) - return callOp.emitError("The type converter has not been defined for " - "BufferAssignmentCallOpConverter"); - Location loc = callOp.getLoc(); - - // If the memref typed results can be returned as function results, there is - // no need to create output buffers. It is only required to convert the type - // of operands and results in place for creating the new `CallOp`. - if (allowMemrefFunctionResults) { - SmallVector resultTypes; - resultTypes.reserve(callOp.getNumResults()); - for (Type type : callOp.getResultTypes()) - resultTypes.push_back(converter->convertType(type)); - rewriter.replaceOpWithNewOp(callOp, callOp.getCallee(), - resultTypes, operands); - return success(); - } - - SmallVector newOperands, replacingValues; - SmallVector newResultTypes; - unsigned numResults = callOp.getNumResults(); - newOperands.reserve(numResults + operands.size()); - newOperands.append(operands.begin(), operands.end()); - newResultTypes.reserve(numResults); - replacingValues.reserve(numResults); - - // For each memref result of `CallOp` which has not been a memref before - // the type conversion, a new buffer is allocated and passed to the operands - // list of the new `CallOp`. Otherwise, it remains as a caller result.
- for (Value result : callOp.getResults()) { - Type currType = result.getType(); - Type newType = converter->convertType(result.getType()); - if (BufferAssignmentTypeConverter::isConvertedMemref(newType, currType)) { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.restoreInsertionPoint(bufferAssignment->computeAllocPosition( - result.dyn_cast())); - Value alloc = - rewriter.create(loc, newType.dyn_cast()); - newOperands.push_back(alloc); - replacingValues.push_back(alloc); - } else { - newResultTypes.push_back(currType); - - // No replacing is required. - replacingValues.push_back(nullptr); - } - } - - // Creating the new `CallOp`. - rewriter.create(loc, callOp.getCallee(), newResultTypes, - newOperands); - - // Replacing the results of the old `CallOp`. - rewriter.replaceOp(callOp, replacingValues); - return success(); - } + /// Performs the actual rewriting step. + LogicalResult matchAndRewrite(CallOp, ArrayRef, + ConversionPatternRewriter &) const override; }; -} // end namespace detail /// Populates `patterns` with the conversion patterns of buffer /// assignment.
+ template + typename CopyOpTy> static void populateWithBufferAssignmentOpConversionPatterns( MLIRContext *context, BufferAssignmentPlacer *placer, - TypeConverter *converter, OwningRewritePatternList *patterns) { + BufferAssignmentTypeConverter *converter, + OwningRewritePatternList *patterns) { // clang-format off patterns->insert< - detail::BufferAssignmentCallOpConverter, - detail::BufferAssignmentFuncOpConverter, - detail::BufferAssignmentReturnOpConverter - + BufferAssignmentCallOpConverter, + BufferAssignmentFuncOpConverter, + BufferAssignmentReturnOpConverter + >(context, placer, converter); // clang-format on } diff --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp @@ -100,11 +100,11 @@ /// tensors to buffers. static void populateConvertLinalgOnTensorsToBuffersPattern( MLIRContext *context, BufferAssignmentPlacer *placer, - TypeConverter *converter, OwningRewritePatternList *patterns) { + BufferAssignmentTypeConverter *converter, + OwningRewritePatternList *patterns) { populateWithBufferAssignmentOpConversionPatterns< - mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp, - /*allowMemrefFunctionResults=*/false>(context, placer, converter, - patterns); + mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(context, placer, + converter, patterns); patterns->insert(context, placer, converter); } @@ -141,6 +141,9 @@ converter.isLegal(&funcOp.getBody()); }); + converter.setResultConversionKind( + BufferAssignmentTypeConverter::AppendToArgumentsList); + // Walk over all the functions to apply buffer assignment.
getOperation().walk([&](FuncOp function) -> WalkResult { OwningRewritePatternList patterns; diff --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp --- a/mlir/lib/Transforms/BufferPlacement.cpp +++ b/mlir/lib/Transforms/BufferPlacement.cpp @@ -710,9 +710,230 @@ }); } -/// Checks if `type` has been converted from non-memref type to memref. -bool BufferAssignmentTypeConverter::isConvertedMemref(Type type, Type before) { - return type.isa() && !before.isa(); +/// This method tries to decompose a value of a certain type using provided +/// decompose callback functions. If none applies, the original value is +/// appended to `results` unchanged. +void BufferAssignmentTypeConverter::tryDecomposeValue( + OpBuilder &builder, Location loc, Type type, Value value, + SmallVectorImpl &results) { + for (const auto &conversion : decomposeValueConversions) + if (conversion(builder, loc, type, value, results) != llvm::None) + return; + results.push_back(value); +} + +/// This method tries to decompose a type using provided decompose callback +/// functions. If none applies, the original type is appended to `types`. +void BufferAssignmentTypeConverter::tryDecomposeType( + Type type, SmallVectorImpl &types) { + for (const auto &conversion : decomposeTypeConversions) + if (conversion(type, types) != llvm::None) + return; + types.push_back(type); +} + +/// Returns the ResultConversionKind for mapping `origin` to `converted`.
+BufferAssignmentTypeConverter::ResultConversionKind +BufferAssignmentTypeConverter::getResultConversionKind(Type origin, + Type converted) { + for (const auto &conversion : resultTypeConversions) { + auto res = conversion(origin, converted); + if (res != llvm::None) + return res.getValue(); + } + // No registered rule matched: keep the value as a function result. + return KeepAsFunctionResult; +} + +//===----------------------------------------------------------------------===// +// BufferAssignmentFuncOpConverter +//===----------------------------------------------------------------------===// + +/// Performs the actual function signature rewriting step. +LogicalResult BufferAssignmentFuncOpConverter::matchAndRewrite( + mlir::FuncOp funcOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + auto funcType = funcOp.getType(); + + // Convert function arguments using the provided TypeConverter. + TypeConverter::SignatureConversion conversion(funcType.getNumInputs()); + for (auto argType : llvm::enumerate(funcType.getInputs())) { + SmallVector decomposedTypes, convertedTypes; + converter->tryDecomposeType(argType.value(), decomposedTypes); + if (failed(converter->convertTypes(decomposedTypes, convertedTypes))) + return failure(); + conversion.addInputs(argType.index(), convertedTypes); + } + + // Convert the result types of the function.
+ SmallVector newResultTypes; + newResultTypes.reserve(funcOp.getNumResults()); + for (Type resultType : funcType.getResults()) { + SmallVector originTypes; + converter->tryDecomposeType(resultType, originTypes); + for (auto origin : originTypes) { + Type converted = converter->convertType(origin); + if (!converted) + return failure(); + auto kind = converter->getResultConversionKind(origin, converted); + if (kind == BufferAssignmentTypeConverter::AppendToArgumentsList) + conversion.addInputs(converted); + else + // kind = BufferAssignmentTypeConverter::KeepAsFunctionResult + newResultTypes.push_back(converted); + } + } + + if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *converter, + &conversion))) + return failure(); + + // Update the signature of the function. + rewriter.updateRootInPlace(funcOp, [&] { + funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(), + newResultTypes)); + }); + return success(); +} + +//===----------------------------------------------------------------------===// +// BufferAssignmentCallOpConverter +//===----------------------------------------------------------------------===// + +/// Performs the actual rewriting step. +LogicalResult BufferAssignmentCallOpConverter::matchAndRewrite( + CallOp callOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + + // This class represents a mapping from a result to a list of values and some + // results that have not been constructed yet. Instead, the indices of these + // results in the operation that will be constructed are known. They will be + // replaced with the actual values when they are available. The order of + // adding to this mapping is important. + class ResultMapping { + public: + /// Add an available value to the mapping. + void addMapping(Value value) { + toValuesMapping.push_back({order++, value}); + } + + /// Add the index of an unavailable result value to the mapping.
+ void addMapping(unsigned index) { + toIndicesMapping.push_back({order++, index}); + } + + /// This method returns the mapping values list. Result values that are only + /// known by their index are looked up in `valuesToReplaceIndices`. + void getMappingValues(ValueRange valuesToReplaceIndices, + SmallVectorImpl &values) { + // Append available values to the list. + SmallVector, 2> res(toValuesMapping.begin(), + toValuesMapping.end()); + // Replace the indices with the actual values. + llvm::for_each( + toIndicesMapping, [&](const std::pair &entry) { + assert(entry.second < valuesToReplaceIndices.size() && + "The value index is out of range."); + res.push_back({entry.first, valuesToReplaceIndices[entry.second]}); + }); + // Sort the values based on their adding orders. + llvm::sort(res, [](const std::pair &v1, + const std::pair &v2) { + return v1.first < v2.first; + }); + // Fill the values. + llvm::for_each(res, [&](const std::pair &entry) { + values.push_back(entry.second); + }); + } + + private: + /// The insertion order that will be assigned to the next mapping entry. + unsigned order = 0; + + /// Containing the mapping values with their inserting orders. + SmallVector, 2> toValuesMapping; + + /// Containing the indices of result values with their inserting orders. + SmallVector, 2> toIndicesMapping; + }; + + Location loc = callOp.getLoc(); + OpBuilder builder(callOp); + SmallVector newOperands; + + // Create the operands list of the new `CallOp`. It unpacks the decomposable + // values if a decompose callback function has been provided by the user. + for (auto operand : operands) { + SmallVector values; + this->converter->tryDecomposeValue(builder, loc, operand.getType(), operand, + values); + newOperands.append(values.begin(), values.end()); + } + + // Create the new result types for the new `CallOp` and a mapping from the old + // result to new value(s).
+ SmallVector newResultTypes; + SmallVector mappings; + mappings.resize(callOp.getNumResults()); + for (auto result : llvm::enumerate(callOp.getResults())) { + SmallVector originTypes; + converter->tryDecomposeType(result.value().getType(), originTypes); + auto &resultMapping = mappings[result.index()]; + for (Type origin : originTypes) { + Type converted = converter->convertType(origin); + if (!converted) + return failure(); + auto kind = converter->getResultConversionKind(origin, converted); + if (kind == BufferAssignmentTypeConverter::KeepAsFunctionResult) { + newResultTypes.push_back(converted); + // The result value is not yet available. Its index is kept and it is + // replaced with the actual value of the new `CallOp` later. + resultMapping.addMapping(newResultTypes.size() - 1); + } else { + // kind = BufferAssignmentTypeConverter::AppendToArgumentsList + if (!bufferAssignment) + return callOp.emitError("The buffer assignment placer has not been " + "defined for BufferAssignmentCallOpConverter"); + OpBuilder::InsertionGuard guard(rewriter); + rewriter.restoreInsertionPoint( + bufferAssignment->computeAllocPosition(result.value())); + MemRefType memref = converted.dyn_cast(); + if (!memref) + return callOp.emitError("Cannot allocate for a non-Memref type"); + Value alloc = rewriter.create(loc, memref); + newOperands.push_back(alloc); + resultMapping.addMapping(alloc); + } + } + } + + CallOp newCallOp = rewriter.create(loc, callOp.getCallee(), + newResultTypes, newOperands); + + // Build a replacing value for each result to replace its uses. If a result + // has multiple mapping values, it needs to be packed to a single value. + OpBuilder nextBuilder(callOp.getOperation()->getNextNode()); + SmallVector replacedValues; + replacedValues.reserve(callOp.getNumResults()); + for (unsigned i = 0, e = callOp.getNumResults(); i < e; ++i) { + SmallVector valuesToPack; + mappings[i].getMappingValues(newCallOp.getResults(), valuesToPack); + if (valuesToPack.empty()) { + // No replacement is required.
+ replacedValues.push_back(nullptr); + } else if (valuesToPack.size() == 1) { + replacedValues.push_back(valuesToPack.front()); + } else { + // Values need to be packed using callback function. The same callback + // that is used for materializeArgumentConversion is used for packing. + Value packed = converter->materializeArgumentConversion( + nextBuilder, loc, callOp.getType(i), valuesToPack); + replacedValues.push_back(packed); + } + } + rewriter.replaceOp(callOp, replacedValues); + return success(); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir b/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir --- a/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir +++ b/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir @@ -111,7 +111,73 @@ // CHECK: %[[Y:.*]]:2 = call @callee(%[[X]]#0) // CHECK: return %[[Y]]#0 +// ----- + +// Test case: Testing BufferAssignmentCallOpConverter to see if it matches with +// the new signature of the callee function when there are tuple typed args and +// results. BufferAssignmentTypeConverter is set to flatten tuple typed +// arguments. The tuple typed values should be decomposed and composed using +// get_tuple_element and make_tuple operations of test dialect. Tensor types are +// converted to Memref. Memref typed function results remain as function results.
+// CHECK-LABEL: func @callee +func @callee(%arg0: tuple,i1, tensor<5xf32>>) -> (tuple,i1, tensor<5xf32>>){ + return %arg0 : tuple,i1, tensor<5xf32>> +} +// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>) +// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>) +// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) +// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32} +// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32} +// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32} +// CHECK-NEXT: return %[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]] +// CHECK-LABEL: func @caller +func @caller(%arg0: tuple,i1, tensor<5xf32>>) -> tuple,i1, tensor<5xf32>>{ + %x0 = call @callee(%arg0) : (tuple,i1, tensor<5xf32>>) -> (tuple,i1, tensor<5xf32>>) + %y0 = call @callee(%x0) : (tuple,i1, tensor<5xf32>>) -> (tuple,i1, tensor<5xf32>>) + return %y0 : tuple,i1, tensor<5xf32>> +} +// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>) +// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>) +// CHECK-NEXT: %[[ARG_TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) +// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 0 : i32} +// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 1 : i32} +// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 2 : i32} +// CHECK-NEXT: %[[CALLEE_RESULTS:.*]]:3 = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]]) +// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>) -> (memref<2xf32>, i1, memref<5xf32>) +// CHECK-NEXT: %[[RESULT_TUPLE:.*]] = "test.make_tuple"(%[[CALLEE_RESULTS]]#0, %[[CALLEE_RESULTS]]#1, %[[CALLEE_RESULTS]]#2) +// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[RESULT_TUPLE]]) {index = 0 : i32} +//
CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[RESULT_TUPLE]]) {index = 1 : i32} +// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[RESULT_TUPLE]]) {index = 2 : i32} +// CHECK-NEXT: %[[CALLEE_RESULTS:.*]]:3 = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]]) +// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>) -> (memref<2xf32>, i1, memref<5xf32>) +// CHECK-NEXT: %[[RETURN_TUPLE:.*]] = "test.make_tuple"(%[[CALLEE_RESULTS]]#0, %[[CALLEE_RESULTS]]#1, %[[CALLEE_RESULTS]]#2) +// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[RETURN_TUPLE]]) {index = 0 : i32} +// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[RETURN_TUPLE]]) {index = 1 : i32} +// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[RETURN_TUPLE]]) {index = 2 : i32} +// CHECK-NEXT: return %[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]] +// ----- +// Test case: Testing BufferAssignmentFuncOpConverter and +// BufferAssignmentReturnOpConverter to see if the return operation matches with +// the new function signature when there are tuple typed args and results. +// BufferAssignmentTypeConverter is set to flatten tuple typed arguments. The tuple +// typed values should be decomposed and composed using get_tuple_element and +// make_tuple operations of test dialect. Tensor types are converted to Memref. +// Memref typed function results remain as function results.
+ +// CHECK-LABEL: func @decompose_tuple_typed_function_args_and_results +func @decompose_tuple_typed_function_args_and_results(%arg0: tuple, %arg1: tensor<10xf32>, %arg2: tuple>) -> (tuple>, tensor<10xf32>, tuple){ + return %arg2, %arg1, %arg0 : tuple>, tensor<10xf32>, tuple +} +// CHECK-SAME: %[[ARG0:.*]]: i1, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: memref<10xf32>, %[[ARG3:.*]]: i1, %[[ARG4:.*]]: memref<5xf32> +// CHECK-SAME: (i1, memref<5xf32>, memref<10xf32>, i1, f32) +// CHECK-NEXT: %[[FIRST_TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]]) +// CHECK-NEXT: %[[SECOND_TUPLE:.*]] = "test.make_tuple"(%[[ARG3]], %[[ARG4]]) +// CHECK-NEXT: %[[SECOND_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 0 : i32} +// CHECK-NEXT: %[[SECOND_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 1 : i32} +// CHECK-NEXT: %[[FIRST_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 0 : i32} +// CHECK-NEXT: %[[FIRST_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 1 : i32} +// CHECK-NEXT: return %[[SECOND_TUPLE_FIRST_ELEM]], %[[SECOND_TUPLE_SECOND_ELEM]], %[[ARG2]], %[[FIRST_TUPLE_FIRST_ELEM]], %[[FIRST_TUPLE_SECOND_ELEM]] diff --git a/mlir/test/Transforms/buffer-placement-preparation.mlir b/mlir/test/Transforms/buffer-placement-preparation.mlir --- a/mlir/test/Transforms/buffer-placement-preparation.mlir +++ b/mlir/test/Transforms/buffer-placement-preparation.mlir @@ -285,8 +285,93 @@ // CHECK: linalg.copy(%[[Y0]], %[[CALLER_RESULT]]) // CHECK: return +// ----- + // CHECK-LABEL: func @func_with_unranked_arg func @func_with_unranked_arg(%arg0: tensor<*xf32>) { return } // CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>) + +// ----- + +// Test case: Testing BufferAssignmentCallOpConverter to see if it matches with +// the new signature of the callee function when there are tuple typed args and +// results.
BufferAssignmentTypeConverter is set to flatten tuple typed +// arguments. The tuple typed values should be decomposed and composed using +// get_tuple_element and make_tuple operations of the test dialect. Tensor types are +// converted to memref. Memref typed function results are appended to the function +// arguments list. + +// CHECK-LABEL: func @callee +func @callee(%arg0: tuple,i1, tensor<5xf32>>) -> (tuple,i1, tensor<5xf32>>){ + return %arg0 : tuple,i1, tensor<5xf32>> +} +// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>, %[[RESULT0:.*]]: memref<2xf32>, %[[RESULT1:.*]]: memref<5xf32>) +// CHECK-SAME: i1 +// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) +// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32} +// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32} +// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32} +// CHECK-NEXT: linalg.copy(%[[FIRST_ELEM]], %[[RESULT0]]) +// CHECK-NEXT: linalg.copy(%[[THIRD_ELEM]], %[[RESULT1]]) +// CHECK-NEXT: return %[[SECOND_ELEM]] + + +// CHECK-LABEL: func @caller +func @caller(%arg0: tuple,i1, tensor<5xf32>>) -> tuple,i1, tensor<5xf32>>{ + %x0 = call @callee(%arg0) : (tuple,i1, tensor<5xf32>>) -> (tuple,i1, tensor<5xf32>>) + %y0 = call @callee(%x0) : (tuple,i1, tensor<5xf32>>) -> (tuple,i1, tensor<5xf32>>) + return %y0 : tuple,i1, tensor<5xf32>> +} +// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>, %[[RESULT0:.*]]: memref<2xf32>, %[[RESULT1:.*]]: memref<5xf32>) +// CHECK-SAME: i1 +// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) +// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32} +// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32} +// CHECK-NEXT: 
%[[THIRD_ELEM:.*]] = 
"test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32} +// CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc() +// CHECK-NEXT: %[[SECOND_ALLOC:.*]] = alloc() +// CHECK-NEXT: %[[CALLEE_RESULT:.*]] = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]], %[[FIRST_ALLOC]], %[[SECOND_ALLOC]]) +// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>, memref<2xf32>, memref<5xf32>) -> i1 +// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[FIRST_ALLOC]], %[[CALLEE_RESULT]], %[[SECOND_ALLOC]]) +// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32} +// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32} +// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32} +// CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc() +// CHECK-NEXT: %[[SECOND_ALLOC:.*]] = alloc() +// CHECK-NEXT: %[[CALLEE_RESULT:.*]] = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]], %[[FIRST_ALLOC]], %[[SECOND_ALLOC]]) +// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>, memref<2xf32>, memref<5xf32>) -> i1 +// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[FIRST_ALLOC]], %[[CALLEE_RESULT]], %[[SECOND_ALLOC]]) +// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32} +// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32} +// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32} +// CHECK-NEXT: linalg.copy(%[[FIRST_ELEM]], %[[RESULT0]]) +// CHECK-NEXT: linalg.copy(%[[THIRD_ELEM]], %[[RESULT1]]) +// CHECK-NEXT: return %[[SECOND_ELEM]] + +// ----- + +// Test case: Testing BufferAssginmnetFuncOpConverter and +// BufferAssginmentReturnOpConverter to see if the return operation matches with +// the new function signature when there are tuple typed args and results. +// BufferAssginmentTypeConverter is set to flatten tuple typed arguments. 
The tuple +// typed values should be decomposed and composed using get_tuple_element and +// make_tuple operations of test dialect. Tensor types are converted to Memref. +// Memref typed function results are appended to the function arguments list. + +// CHECK-LABEL: func @decompose_tuple_typed_function_args_and_results +func @decompose_tuple_typed_function_args_and_results(%arg0: tuple, %arg1: tensor<10xf32>, %arg2: tuple>) -> (tuple>, tensor<10xf32>, tuple){ + return %arg2, %arg1, %arg0 : tuple>, tensor<10xf32>, tuple +} +// CHECK-SAME: %[[ARG0:.*]]: i1, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: memref<10xf32>, %[[ARG3:.*]]: i1, %[[ARG4:.*]]: memref<5xf32>, %[[RESULT0:.*]]: memref<5xf32>, %[[RESULT1:.*]]: memref<10xf32> +// CHECK-SAME: (i1, i1, f32) +// CHECK-NEXT: %[[FIRST_TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]]) +// CHECK-NEXT: %[[SECOND_TUPLE:.*]] = "test.make_tuple"(%[[ARG3]], %[[ARG4]]) +// CHECK-NEXT: %[[SECOND_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 0 : i32} +// CHECK-NEXT: %[[SECOND_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 1 : i32} +// CHECK-NEXT: %[[FIRST_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 0 : i32} +// CHECK-NEXT: %[[FIRST_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 1 : i32} +// CHECK-NEXT: linalg.copy(%[[SECOND_TUPLE_SECOND_ELEM]], %[[RESULT0]]) +// CHECK-NEXT: linalg.copy(%[[ARG2]], %[[RESULT1]]) +// CHECK-NEXT: return %[[SECOND_TUPLE_FIRST_ELEM]], %[[FIRST_TUPLE_FIRST_ELEM]], %[[FIRST_TUPLE_SECOND_ELEM]] diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1473,4 +1473,27 @@ }]; } +def GetTupleElementOp: TEST_Op<"get_tuple_element"> { + let description = [{ + Test op that returns a specified element of the tuple. 
+ }]; + + let arguments = (ins + TupleOf<[AnyType]>, + I32Attr:$index + ); + let results = (outs AnyType); +} + +def MakeTupleOp: TEST_Op<"make_tuple"> { + let description = [{ + Test op that creates a tuple value from a list of values. + }]; + + let arguments = (ins + Variadic:$inputs + ); + let results = (outs TupleOf<[AnyType]>); +} + #endif // TEST_OPS diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp --- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp +++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp @@ -11,6 +11,8 @@ // //===----------------------------------------------------------------------===// +#include "TestDialect.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/IR/Function.h" #include "mlir/IR/Operation.h" @@ -109,10 +111,11 @@ void populateTensorLinalgToBufferLinalgConversionPattern( MLIRContext *context, BufferAssignmentPlacer *placer, - TypeConverter *converter, OwningRewritePatternList *patterns) { + BufferAssignmentTypeConverter *converter, + OwningRewritePatternList *patterns) { populateWithBufferAssignmentOpConversionPatterns< - mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp, - allowMemrefFunctionResults>(context, placer, converter, patterns); + mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(context, placer, + converter, patterns); patterns->insert(context, placer, converter); } @@ -123,6 +126,8 @@ // Mark all Standard operations legal. target.addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); // Mark all Linalg operations illegal as long as they work on tensors. auto isLegalOperation = [&](Operation *op) { @@ -145,6 +150,42 @@ converter.isLegal(&funcOp.getBody()); }); + auto kind = allowMemrefFunctionResults + ? 
BufferAssignmentTypeConverter::KeepAsFunctionResult + : BufferAssignmentTypeConverter::AppendToArgumentsList; + converter.setResultConversionKind(kind); + converter.setResultConversionKind( + kind); + + converter.addDecomposeTypeConversion( + [](TupleType tupleType, SmallVectorImpl &types) { + tupleType.getFlattenedTypes(types); + return success(); + }); + + converter.addArgumentMaterialization( + [](OpBuilder &builder, TupleType resultType, ValueRange inputs, + Location loc) -> Optional { + if (inputs.size() == 1) + return llvm::None; + TypeRange TypeRange = inputs.getTypes(); + SmallVector types(TypeRange.begin(), TypeRange.end()); + TupleType tuple = TupleType::get(types, builder.getContext()); + mlir::Value value = builder.create(loc, tuple, inputs); + return value; + }); + + converter.addDecomposeValueConversion([](OpBuilder &builder, Location loc, + TupleType resultType, Value value, + SmallVectorImpl &values) { + for (unsigned i = 0, e = resultType.size(); i < e; ++i) { + Value res = builder.create( + loc, resultType.getType(i), value, builder.getI32IntegerAttr(i)); + values.push_back(res); + } + return success(); + }); + // Walk over all the functions to apply buffer assignment. this->getOperation().walk([&](FuncOp function) -> WalkResult { OwningRewritePatternList patterns;