diff --git a/mlir/include/mlir/Transforms/BufferPlacement.h b/mlir/include/mlir/Transforms/BufferPlacement.h --- a/mlir/include/mlir/Transforms/BufferPlacement.h +++ b/mlir/include/mlir/Transforms/BufferPlacement.h @@ -52,6 +52,74 @@ Operation *operation; }; +/// A helper type converter class for using inside Buffer Assignment operation +/// conversion patterns. The default constructor keeps all the types intact +/// except for the ranked-tensor types which is converted to memref types. +class BufferAssignmentTypeConverter : public TypeConverter { +public: + /// This enum is for showing how buffer placement operation converters should + /// conduct with certain result type after type conversion. This value can be + /// set/get for each specific type using setResultConversionKind or + /// getResultConversionKind. + enum ResultConversionKind { AppendToArgumentsList, KeepAsFunctionResult }; + + BufferAssignmentTypeConverter(); + + /// This method tries to decompose a value of a certain type using provided + /// decompose callback functions. If it is unable to do so, the original value + /// is returned. + void tryDecomposeValue(OpBuilder &, Location, Type, Value, + SmallVectorImpl &); + + /// This method registers a callback function that will be called to decompose + /// a value of a certain type into several values. + template ::template arg_t<2>> + void addDecomposeValueConversion(FnT &&callback) { + decomposeValueConversions.emplace_back( + wrapDecomposeValueConversionCallback(std::forward(callback))); + } + + /// This method returns ResultConversionKind for the input Type. + ResultConversionKind getResultConversionKind(Type); + + /// This method sets ResultConversionKind for the template type `T`. 
+ template + void setResultConversionKind(ResultConversionKind kind) { + assert( + (kind != AppendToArgumentsList || std::is_same::value) && + "Only the memref typed values can be set to be appended to the " + "function argument list at the moment"); + resultTypeConversions.emplace_back( + [kind](Type type) -> Optional { // Copy `kind`: this callback outlives the call. + if (T derivedType = type.dyn_cast()) + return kind; + return llvm::None; + }); + } + +private: + using DecomposeValueConversionCallFn = std::function( + OpBuilder &, Location, Type, Value, SmallVectorImpl &)>; + + /// Generate a wrapper for the given decompose value conversion callback. + template + DecomposeValueConversionCallFn + wrapDecomposeValueConversionCallback(FnT &&callback) { + return [callback = std::forward(callback)]( + OpBuilder &builder, Location loc, Type type, Value value, + SmallVectorImpl &newValues) -> Optional { + if (T derivedType = type.dyn_cast()) + return callback(builder, loc, derivedType, value, newValues); + return llvm::None; + }; + } + + SmallVector(Type)>, 2> + resultTypeConversions; + SmallVector decomposeValueConversions; +}; + /// Helper conversion pattern that encapsulates a BufferAssignmentPlacer /// instance. Sample usage: /// class CustomConversionPattern : public @@ -68,43 +136,20 @@ public: explicit BufferAssignmentOpConversionPattern( MLIRContext *context, BufferAssignmentPlacer *bufferAssignment = nullptr, - TypeConverter *converter = nullptr, PatternBenefit benefit = 1) + BufferAssignmentTypeConverter *converter = nullptr, + PatternBenefit benefit = 1) : OpConversionPattern(context, benefit), bufferAssignment(bufferAssignment), converter(converter) {} protected: BufferAssignmentPlacer *bufferAssignment; - TypeConverter *converter; -}; - -/// A helper type converter class for using inside Buffer Assignment operation -/// conversion patterns. The default constructor keeps all the types intact -/// except for the ranked-tensor types which is converted to memref types.
-class BufferAssignmentTypeConverter : public TypeConverter { -public: - BufferAssignmentTypeConverter(); - - /// A helper function to check if `type` has been converted from non-memref - /// type to memref. - static bool isConvertedMemref(Type type, Type before); + BufferAssignmentTypeConverter *converter; }; -namespace detail { - -/// Converts the signature of the function based on whether the function is -/// allowed to return memref typed results or not using -/// `allowMemrefFunctionResults` parameter. If this option is false, then it -/// adds an extra function argument as an output buffer for each function result -/// which is going to be a memref type only after type conversion. The -/// other function result types remain unchanged. If -/// `allowMemrefFunctionResults` is true, the types are converted in place. -/// Any changes in function signature need to be applied -/// to return and caller operations. `BufferAssignmentReturnOpConverter` and -/// `BufferAssignmentCallOpConverter` are two helper function that match the -/// return and caller operation with the new function signature. Furthermore, -/// `BufferAssignmentTypeConverter` is a helper `TypeConverter` for converting -/// tensor typed values to memref typed ones. -template +/// Converts the signature of the function using BufferAssignmentTypeConverter. +/// Each result type of the function is kept as a function result or appended to +/// the function arguments list based on ResultConversionKind for the converted +/// result type. class BufferAssignmentFuncOpConverter : public BufferAssignmentOpConversionPattern { public: @@ -112,58 +157,16 @@ FuncOp>::BufferAssignmentOpConversionPattern; /// Performs the actual signature rewriting step. 
- LogicalResult - matchAndRewrite(mlir::FuncOp funcOp, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - if (!converter) - return funcOp.emitError("The type converter has not been defined for " - "BufferAssignmentFuncOpConverter"); - auto funcType = funcOp.getType(); - - // Convert function arguments using the provided TypeConverter. - TypeConverter::SignatureConversion conversion(funcType.getNumInputs()); - for (auto argType : llvm::enumerate(funcType.getInputs())) - conversion.addInputs(argType.index(), - converter->convertType(argType.value())); - - // If allowMemrefFunctionResults is false and a function result type is not - // a memref but it would be a memref after type conversion, a new argument - // should be appended to the function arguments list for this result. - // Otherwise, it remains unchanged as a function result. - SmallVector newResultTypes; - newResultTypes.reserve(funcOp.getNumResults()); - for (Type resType : funcType.getResults()) { - Type convertedType = converter->convertType(resType); - if (!allowMemrefFunctionResults && - BufferAssignmentTypeConverter::isConvertedMemref(convertedType, - resType)) - conversion.addInputs(convertedType); - else - newResultTypes.push_back(convertedType); - } - if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *converter, - &conversion))) - return failure(); - - // Update the signature of the function. - rewriter.updateRootInPlace(funcOp, [&] { - funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(), - newResultTypes)); - }); - return success(); - } + LogicalResult matchAndRewrite(mlir::FuncOp, ArrayRef, + ConversionPatternRewriter &) const; }; /// Rewrites the `ReturnOp` to conform with the changed function signature. -/// if allowMemrefFunctionResults is false, operands that correspond to return -/// values and have been rewritten from illegal typed results to memref -/// arguments are dropped. 
In their place, a corresponding copy operation from -/// the operand to the output function argument is inserted. Otherwise, the -/// memref typed operands are returned. -/// Note: If this pattern rewriter is used with BufferAssignmentFuncOpConverter, -/// allowMemrefFunctionResults must be set/unset for both. +/// Operands that correspond to return values whose types have been set to +/// AppendToArgumentsList are dropped. In their place, a corresponding copy +/// operation from the operand to the target function argument is inserted. template + typename CopyOpTy> class BufferAssignmentReturnOpConverter : public BufferAssignmentOpConversionPattern { public: @@ -174,44 +177,48 @@ LogicalResult matchAndRewrite(ReturnOpSourceTy returnOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - // If the memref typed results can be returned as function results, the new - // `ReturnOp` should only return the type converted operands. - if (allowMemrefFunctionResults) { - rewriter.replaceOpWithNewOp(returnOp, operands); - return success(); + if (!this->converter) + return returnOp.emitError("The type converter has not been defined for " + "BufferAssignmentReturnOpConverter"); + Location loc = returnOp.getLoc(); + OpBuilder builder(returnOp); + SmallVector newOperands, needCopyOperands; + + // Split the operands by whether they need a copy operation or remain as + // operands of the return operation. If an operand is decomposable and a + // decompose callback function has been provided by the user, it will be + // unpacked.
+ for (Value operand : operands) { + SmallVector values; + this->converter->tryDecomposeValue(builder, loc, operand.getType(), + operand, values); + for (Value value : values) { + auto kind = this->converter->getResultConversionKind(value.getType()); + if (kind == BufferAssignmentTypeConverter::KeepAsFunctionResult) + newOperands.push_back(value); + else // kind = BufferAssignmentTypeConverter::AppendToArgumentsList + needCopyOperands.push_back(value); + } } - // Split the operands by their kinds whether they are converted memref or - // not. - SmallVector needCopyOperands, newOperands; - unsigned operandsSize = operands.size(); - needCopyOperands.reserve(operandsSize); - newOperands.reserve(operandsSize); - for (auto operand : llvm::enumerate(operands)) - if (BufferAssignmentTypeConverter::isConvertedMemref( - operand.value().getType(), - returnOp.getOperand(operand.index()).getType())) - needCopyOperands.push_back(operand.value()); - else - newOperands.push_back(operand.value()); - + // Insert Copy operations instead for the operands that have been removed + // from operand list and appended to the function arguments list. Block &entryBlock = returnOp.getParentRegion()->front(); - unsigned numFuncArgs = entryBlock.getNumArguments(); - - // Find the index of the first destination buffer. - assert(needCopyOperands.size() <= numFuncArgs && - "The number of operands of return operation is more than the " - "number of function arguments."); - unsigned destArgNum = numFuncArgs - needCopyOperands.size(); + int numFuncArgs = entryBlock.getNumArguments(); + int destArgNum = numFuncArgs - needCopyOperands.size(); rewriter.setInsertionPoint(returnOp); for (Value operand : needCopyOperands) { - // Insert a `CopyOp` for each converted memref-type operand. 
- rewriter.create(returnOp.getLoc(), operand, + assert(destArgNum >= 0 && destArgNum < numFuncArgs && + "The number of operands that need Copy " + "operations is more than the " + "number of target function arguments"); + if (!operand.getType().isa()) + return returnOp.emitError( + "Cannot insert a copy for a non-Memref typed value"); + rewriter.create(loc, operand, entryBlock.getArgument(destArgNum)); ++destArgNum; } - - // Insert the new target Return operation. rewriter.replaceOpWithNewOp(returnOp, newOperands); return success(); } @@ -219,94 +226,32 @@ /// Rewrites the `CallOp` to match its operands and results with the signature /// of the callee after rewriting the callee with -/// BufferAssignmentFuncOpConverter. If allowMemrefFunctionResults is false, a -/// buffer is allocated as an output buffer only for each memref typed result -/// that has been rewritten. The new allocated buffer is passed through the -/// operands list of the new `CallOp`. -/// Note: If this pattern rewriter is used with BufferAssignmentFuncOpConverter, -/// allowMemrefFunctionResults must be set/unset for both. -template +/// BufferAssignmentFuncOpConverter. class BufferAssignmentCallOpConverter : public BufferAssignmentOpConversionPattern { public: using BufferAssignmentOpConversionPattern< CallOp>::BufferAssignmentOpConversionPattern; - LogicalResult - matchAndRewrite(CallOp callOp, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - if (!converter) - return callOp.emitError("The type converter has not been defined for " - "BufferAssignmentCallOpConverter"); - Location loc = callOp.getLoc(); - - // If the memref typed results can be returned as function results, there is - // no need to create output buffers. It is only required to convert the type - // of operands and results in place for creating the new `CallOp`. 
- if (allowMemrefFunctionResults) { - SmallVector resultTypes; - resultTypes.reserve(callOp.getNumResults()); - for (Type type : callOp.getResultTypes()) - resultTypes.push_back(converter->convertType(type)); - rewriter.replaceOpWithNewOp(callOp, callOp.getCallee(), - resultTypes, operands); - return success(); - } - - SmallVector newOperands, replacingValues; - SmallVector newResultTypes; - unsigned numResults = callOp.getNumResults(); - newOperands.reserve(numResults + operands.size()); - newOperands.append(operands.begin(), operands.end()); - newResultTypes.reserve(numResults); - replacingValues.reserve(numResults); - - // For each memref result of `CallOp` which has not been a memref before - // the type conversion, a new buffer is allocated and passed to the operands - // list of the new `CallOp`. Otherwise, it remains as a caller result. - for (Value result : callOp.getResults()) { - Type currType = result.getType(); - Type newType = converter->convertType(result.getType()); - if (BufferAssignmentTypeConverter::isConvertedMemref(newType, currType)) { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.restoreInsertionPoint(bufferAssignment->computeAllocPosition( - result.dyn_cast())); - Value alloc = - rewriter.create(loc, newType.dyn_cast()); - newOperands.push_back(alloc); - replacingValues.push_back(alloc); - } else { - newResultTypes.push_back(currType); - - // No replacing is required. - replacingValues.push_back(nullptr); - } - } - - // Creating the new `CallOp`. - rewriter.create(loc, callOp.getCallee(), newResultTypes, - newOperands); - - // Replacing the results of the old `CallOp`. - rewriter.replaceOp(callOp, replacingValues); - return success(); - } + /// Performs the actual rewriting step. + LogicalResult matchAndRewrite(CallOp, ArrayRef, + ConversionPatternRewriter &) const; }; -} // end namespace detail /// Populates `patterns` with the conversion patterns of buffer /// assignment. 
template + typename CopyOpTy> static void populateWithBufferAssignmentOpConversionPatterns( MLIRContext *context, BufferAssignmentPlacer *placer, - TypeConverter *converter, OwningRewritePatternList *patterns) { + BufferAssignmentTypeConverter *converter, + OwningRewritePatternList *patterns) { // clang-format off patterns->insert< - detail::BufferAssignmentCallOpConverter, - detail::BufferAssignmentFuncOpConverter, - detail::BufferAssignmentReturnOpConverter - + BufferAssignmentCallOpConverter, + BufferAssignmentFuncOpConverter, + BufferAssignmentReturnOpConverter + >(context, placer, converter); // clang-format on } diff --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp @@ -100,11 +100,11 @@ /// tensors to buffers. static void populateConvertLinalgOnTensorsToBuffersPattern( MLIRContext *context, BufferAssignmentPlacer *placer, - TypeConverter *converter, OwningRewritePatternList *patterns) { + BufferAssignmentTypeConverter *converter, + OwningRewritePatternList *patterns) { populateWithBufferAssignmentOpConversionPatterns< - mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp, - /*allowMemrefFunctionResults=*/false>(context, placer, converter, - patterns); + mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(context, placer, + converter, patterns); patterns->insert(context, placer, converter); } @@ -141,6 +141,9 @@ converter.isLegal(&funcOp.getBody()); }); + converter.setResultConversionKind( + BufferAssignmentTypeConverter::AppendToArgumentsList); + // Walk over all the functions to apply buffer assignment. 
getOperation().walk([&](FuncOp function) -> WalkResult { OwningRewritePatternList patterns; diff --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp --- a/mlir/lib/Transforms/BufferPlacement.cpp +++ b/mlir/lib/Transforms/BufferPlacement.cpp @@ -710,9 +710,208 @@ }); } -/// Checks if `type` has been converted from non-memref type to memref. -bool BufferAssignmentTypeConverter::isConvertedMemref(Type type, Type before) { - return type.isa() && !before.isa(); +/// This method tries to decompose a value of a certain type using provided +/// decompose callback functions. If it is unable to do so, the original value +/// is appended to `results` unchanged. +void BufferAssignmentTypeConverter::tryDecomposeValue( + OpBuilder &builder, Location loc, Type type, Value value, + SmallVectorImpl &results) { + for (const auto &conversion : decomposeValueConversions) { + auto res = conversion(builder, loc, type, value, results); + if (res != llvm::None) + return; + } + results.push_back(value); +} + +/// This method returns ResultConversionKind for the input type. +BufferAssignmentTypeConverter::ResultConversionKind +BufferAssignmentTypeConverter::getResultConversionKind(Type type) { + for (const auto &conversion : resultTypeConversions) { + auto res = conversion(type); + if (res != llvm::None) + return res.getValue(); + } + return KeepAsFunctionResult; +} + +//===----------------------------------------------------------------------===// +// BufferAssignmentFuncOpConverter +//===----------------------------------------------------------------------===// + +/// Performs the actual function signature rewriting step.
+LogicalResult BufferAssignmentFuncOpConverter::matchAndRewrite( + mlir::FuncOp funcOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + if (!converter) + return funcOp.emitError("The type converter has not been defined for " + "BufferAssignmentFuncOpConverter"); + + auto funcType = funcOp.getType(); + + // Convert function arguments using the provided TypeConverter. + TypeConverter::SignatureConversion conversion(funcType.getNumInputs()); + for (auto argType : llvm::enumerate(funcType.getInputs())) { + SmallVector convertedTypes; + converter->convertType(argType.value(), convertedTypes); + conversion.addInputs(argType.index(), convertedTypes); + } + + // Convert the result types of the function. + SmallVector newResultTypes; + newResultTypes.reserve(funcOp.getNumResults()); + for (Type resultType : funcType.getResults()) { + SmallVector convertedTypes; + converter->convertType(resultType, convertedTypes); + for (auto type : convertedTypes) { + auto kind = converter->getResultConversionKind(type); + if (kind == BufferAssignmentTypeConverter::AppendToArgumentsList) + conversion.addInputs(type); + else // kind = BufferAssignmentTypeConverter::KeepAsFunctionResult + newResultTypes.push_back(type); + } + } + if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *converter, + &conversion))) + return failure(); + + // Update the signature of the function. + rewriter.updateRootInPlace(funcOp, [&] { + funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(), + newResultTypes)); + }); + return success(); +} + +//===----------------------------------------------------------------------===// +// BufferAssignmentCallOpConverter +//===----------------------------------------------------------------------===// + +/// Performs the actual rewriting step. 
+LogicalResult BufferAssignmentCallOpConverter::matchAndRewrite( + CallOp callOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + + // This class represents a mapping from a result to a list of values and some + // results that have not yet been constructed. Instead, the indices of these + // results in the operation that will be constructed are known. They will be + // replaced with the actual values when they are available. The order of + // adding to this mapping is important. + class ResultMapping { + public: + ResultMapping() { order = 0; }; + + /// Add an available value to the mapping. + void addMapping(Value value) { + toValuesMapping.push_back({order++, value}); + } + + /// Add the index of an unavailable result value to the mapping. + void addMapping(unsigned index) { + toIndicesMapping.push_back({order++, index}); + } + + /// This method returns the mapping values list. The unknown result values + /// for which only indices are available are replaced with their values. + void getMappingValues(ValueRange valuesToReplaceIndices, + SmallVectorImpl &values) { + // Append available values to the list. + SmallVector, 2> res(toValuesMapping.begin(), + toValuesMapping.end()); + // Replace the indices with the actual values. + llvm::for_each( + toIndicesMapping, [&](std::pair entry) { + assert(entry.second < valuesToReplaceIndices.size() && + "The value index is out of range."); + res.push_back({entry.first, valuesToReplaceIndices[entry.second]}); + }); + // Sort the values based on the order in which they were added. + llvm::sort(res, [](const std::pair &v1, + const std::pair &v2) { + return v1.first < v2.first; + }); + // Fill the values.
+ llvm::for_each(res, [&](auto entry) { values.push_back(entry.second); }); + } + + private: + int order; + SmallVector, 2> toValuesMapping; + SmallVector, 2> toIndicesMapping; + }; + + if (!converter) + return callOp.emitError("The type converter has not been defined for " + "BufferAssignmentCallOpConverter"); + Location loc = callOp.getLoc(); + OpBuilder builder(callOp); + SmallVector newOperands; + + // Create the operands list of the new `CallOp`. It unpacks the decomposable + // values if a decompose callback function has been provided by the user. + for (auto operand : operands) { + SmallVector values; + this->converter->tryDecomposeValue(builder, loc, operand.getType(), operand, + values); + newOperands.append(values.begin(), values.end()); + } + + // Create the new result types for the new `CallOp` and a mapping from the old + // result to new value(s). + SmallVector newResultTypes; + SmallVector mappings; + mappings.resize(callOp.getNumResults()); + for (auto result : llvm::enumerate(callOp.getResults())) { + SmallVector convertedTypes; + converter->convertType(result.value().getType(), convertedTypes); + auto &resultMapping = mappings[result.index()]; + for (Type type : convertedTypes) { + auto kind = converter->getResultConversionKind(type); + if (kind == BufferAssignmentTypeConverter::KeepAsFunctionResult) { + newResultTypes.push_back(type); + // The result value is not yet available. Its index is kept and it is + // replaced with the actual value of the new `CallOp` later. 
+ resultMapping.addMapping(newResultTypes.size() - 1); + } else { // kind = BufferAssignmentTypeConverter::AppendToArgumentsList + OpBuilder::InsertionGuard guard(rewriter); + rewriter.restoreInsertionPoint( + bufferAssignment->computeAllocPosition(result.value())); + MemRefType memref = type.dyn_cast(); + if (!memref) + return callOp.emitError("Cannot allocate for a non-Memref type"); + Value alloc = rewriter.create(loc, memref); + newOperands.push_back(alloc); + resultMapping.addMapping(alloc); + } + } + } + + CallOp newCallOp = rewriter.create(loc, callOp.getCallee(), + newResultTypes, newOperands); + + // Build a replacing value for each result to replace it uses. If a result has + // multiple mapping values, it needs to be packed to a single value. + OpBuilder nextBuilder(callOp.getOperation()->getNextNode()); + SmallVector replacedValues; + replacedValues.reserve(callOp.getNumResults()); + for (unsigned i = 0; i < callOp.getNumResults(); ++i) { + SmallVector valuesToPack; + mappings[i].getMappingValues(newCallOp.getResults(), valuesToPack); + if (valuesToPack.empty()) + // No replacement is required. + replacedValues.push_back(nullptr); + else if (valuesToPack.size() == 1) + replacedValues.push_back(valuesToPack.front()); + else { + // Values need to be packed using callback function. The same callback + // that is used for materializeArgumentConversion is used for packing. 
+ Value packed = converter->materializeArgumentConversion( + nextBuilder, loc, callOp.getType(i), valuesToPack); + replacedValues.push_back(packed); + } + } + rewriter.replaceOp(callOp, replacedValues); + return success(); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir b/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir --- a/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir +++ b/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir @@ -111,7 +111,58 @@ // CHECK: %[[Y:.*]]:2 = call @callee(%[[X]]#0) // CHECK: return %[[Y]]#0 +// ----- +// CHECK-LABEL: func @callee +func @callee(%arg0: tuple,i1, tensor<5xf32>>) -> (tuple,i1, tensor<5xf32>>){ + return %arg0 : tuple,i1, tensor<5xf32>> +} +// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>) +// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>) +// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) +// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32} +// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32} +// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32} +// CHECK-NEXT: return %[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]] +// CHECK-LABEL: func @caller +func @caller(%arg0: tuple,i1, tensor<5xf32>>) -> tuple,i1, tensor<5xf32>>{ + %x0 = call @callee(%arg0) : (tuple,i1, tensor<5xf32>>) -> (tuple,i1, tensor<5xf32>>) + %y0 = call @callee(%x0) : (tuple,i1, tensor<5xf32>>) -> (tuple,i1, tensor<5xf32>>) + return %y0 : tuple,i1, tensor<5xf32>> +} +// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>) +// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>) +// CHECK-NEXT: %[[ARG_TUPLE:.*]] = 
"test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) +// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 0 : i32} +// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 1 : i32} +// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 2 : i32} +// CHECK-NEXT: %[[CALLEE_RESULTS:.*]]:3 = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]]) +// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>) -> (memref<2xf32>, i1, memref<5xf32>) +// CHECK-NEXT: %[[RESULT_TUPLE:.*]] = "test.make_tuple"(%[[CALLEE_RESULTS]]#0, %[[CALLEE_RESULTS]]#1, %[[CALLEE_RESULTS]]#2) +// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[RESULT_TUPLE]]) {index = 0 : i32} +// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[RESULT_TUPLE]]) {index = 1 : i32} +// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[RESULT_TUPLE]]) {index = 2 : i32} +// CHECK-NEXT: %[[CALLEE_RESULTS:.*]]:3 = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]]) +// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>) -> (memref<2xf32>, i1, memref<5xf32>) +// CHECK-NEXT: %[[RETURN_TUPLE:.*]] = "test.make_tuple"(%[[CALLEE_RESULTS]]#0, %[[CALLEE_RESULTS]]#1, %[[CALLEE_RESULTS]]#2) +// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[RETURN_TUPLE]]) {index = 0 : i32} +// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[RETURN_TUPLE]]) {index = 1 : i32} +// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[RETURN_TUPLE]]) {index = 2 : i32} +// CHECK-NEXT: return %[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]] +// ----- +// CHECK-LABEL: func @test +func @test(%arg0: tuple, %arg1: tensor<10xf32>, %arg2: tuple>) -> (tuple>, tensor<10xf32>, tuple){ + return %arg2, %arg1, %arg0 : tuple>, tensor<10xf32>, tuple +} +// CHECK-SAME: %[[ARG0:.*]]: i1, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: memref<10xf32>, %[[ARG3:.*]]: i1, %[[ARG4:.*]]: memref<5xf32> +// 
CHECK-SAME: (i1, memref<5xf32>, memref<10xf32>, i1, f32) +// CHECK-NEXT: %[[FIRST_TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]]) +// CHECK-NEXT: %[[SECOND_TUPLE:.*]] = "test.make_tuple"(%[[ARG3]], %[[ARG4]]) +// CHECK-NEXT: %[[SECOND_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 0 : i32} +// CHECK-NEXT: %[[SECOND_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 1 : i32} +// CHECK-NEXT: %[[FIRST_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 0 : i32} +// CHECK-NEXT: %[[FIRST_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 1 : i32} +// CHECK-NEXT: return %[[SECOND_TUPLE_FIRST_ELEM]], %[[SECOND_TUPLE_SECOND_ELEM]], %[[ARG2]], %[[FIRST_TUPLE_FIRST_ELEM]], %[[FIRST_TUPLE_SECOND_ELEM]] diff --git a/mlir/test/Transforms/buffer-placement-preparation.mlir b/mlir/test/Transforms/buffer-placement-preparation.mlir --- a/mlir/test/Transforms/buffer-placement-preparation.mlir +++ b/mlir/test/Transforms/buffer-placement-preparation.mlir @@ -8,31 +8,6 @@ // ----- -// Only tensor typed function result should be converted to memref and move to the -// function arguments list. The other memref function results remain as function -// results. 
- -#map0 = affine_map<(d0) -> (d0)> - -// CHECK-LABEL: func @memref_in_function_results -func @memref_in_function_results(%arg0: tensor<5xf32>, %arg1: memref<10xf32>) -> (tensor<5xf32>, memref<10xf32>, memref<15xf32>) { - %0 = alloc() : memref<15xf32> - %1 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0 { - ^bb0(%gen1_arg0: f32): - %tmp1 = exp %gen1_arg0 : f32 - linalg.yield %tmp1 : f32 - }: tensor<5xf32> -> tensor<5xf32> - return %1, %arg1, %0 : tensor<5xf32>, memref<10xf32>, memref<15xf32> -} -// CHECK: (%[[ARG0:.*]]: memref<5xf32>, %[[ARG1:.*]]: memref<10xf32>, %[[RESULT:.*]]: memref<5xf32>) -// CHECK-SAME: (memref<10xf32>, memref<15xf32>) -// CHECK: %[[FIRST_ALLOC:.*]] = alloc() -// CHECK: %[[LINALG_ALLOC:.*]] = alloc() -// CHECK: linalg.copy(%[[LINALG_ALLOC]], %[[RESULT]]) -// CHECK: return %[[ARG1]], %[[FIRST_ALLOC]] - -// ----- - // CHECK-LABEL: func @no_signature_conversion_is_needed func @no_signature_conversion_is_needed(%arg0: memref<4x8xf32>) { return @@ -264,12 +239,11 @@ %buff = alloc() : memref<2xf32> return %arg1, %buff : tensor<5xf32>, memref<2xf32> } -// CHECK: (%[[CALLEE_ARG:.*]]: memref<5xf32>, %[[CALLEE_RESULT:.*]]: memref<5xf32>) -// CHECK-SAME: memref<2xf32> -// CHECK: %[[ALLOC:.*]] = alloc() +// CHECK-SAME: (%[[CALLEE_ARG:.*]]: memref<5xf32>, %[[CALLEE_RESULT:.*]]: memref<5xf32>, %[[CALLEE_RESULT_2:.*]]: memref<2xf32>) +// CHECK-NEXT: %[[ALLOC:.*]] = alloc() // CHECK: linalg.copy(%[[CALLEE_ARG]], %[[CALLEE_RESULT]]) -// CHECK: return %[[ALLOC]] - +// CHECK: linalg.copy(%[[ALLOC]], %[[CALLEE_RESULT_2]]) +// CHECK: return // CHECK-LABEL: func @caller func @caller(%arg0: tensor<5xf32>) -> tensor<5xf32> { @@ -279,14 +253,85 @@ } // CHECK: (%[[CALLER_ARG:.*]]: memref<5xf32>, %[[CALLER_RESULT:.*]]: memref<5xf32>) // CHECK: %[[X0:.*]] = alloc() -// CHECK: %[[X1:.*]] = call @callee(%[[CALLER_ARG]], %[[X0]]) +// CHECK: %[[ALLOC0:.*]] = alloc() +// CHECK: call 
@callee(%[[CALLER_ARG]], %[[X0]], %[[ALLOC0]]) // CHECK: %[[Y0:.*]] = alloc() -// CHECK: %[[Y1:.*]] = call @callee(%[[X0]], %[[Y0]]) +// CHECK: %[[ALLOC1:.*]] = alloc() +// CHECK: call @callee(%[[X0]], %[[Y0]], %[[ALLOC1]]) // CHECK: linalg.copy(%[[Y0]], %[[CALLER_RESULT]]) // CHECK: return +// ----- + // CHECK-LABEL: func @func_with_unranked_arg func @func_with_unranked_arg(%arg0: tensor<*xf32>) { return } // CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>) + +// ----- + +// CHECK-LABEL: func @callee +func @callee(%arg0: tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>){ + return %arg0 : tuple<tensor<2xf32>,i1, tensor<5xf32>> +} +// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>, %[[RESULT0:.*]]: memref<2xf32>, %[[RESULT1:.*]]: memref<5xf32>) +// CHECK-SAME: i1 +// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) +// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32} +// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32} +// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32} +// CHECK-NEXT: linalg.copy(%[[FIRST_ELEM]], %[[RESULT0]]) +// CHECK-NEXT: linalg.copy(%[[THIRD_ELEM]], %[[RESULT1]]) +// CHECK-NEXT: return %[[SECOND_ELEM]] + + +// CHECK-LABEL: func @caller +func @caller(%arg0: tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> tuple<tensor<2xf32>,i1, tensor<5xf32>>{ + %x0 = call @callee(%arg0) : (tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>) + %y0 = call @callee(%x0) : (tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>) + return %y0 : tuple<tensor<2xf32>,i1, tensor<5xf32>> +} +// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>, %[[RESULT0:.*]]: memref<2xf32>, %[[RESULT1:.*]]: memref<5xf32>) +// CHECK-SAME: i1 +// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) +// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32} +//
CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32} +// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32} +// CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc() +// CHECK-NEXT: %[[SECOND_ALLOC:.*]] = alloc() +// CHECK-NEXT: %[[CALLEE_RESULT:.*]] = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]], %[[FIRST_ALLOC]], %[[SECOND_ALLOC]]) +// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>, memref<2xf32>, memref<5xf32>) -> i1 +// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[FIRST_ALLOC]], %[[CALLEE_RESULT]], %[[SECOND_ALLOC]]) +// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32} +// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32} +// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32} +// CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc() +// CHECK-NEXT: %[[SECOND_ALLOC:.*]] = alloc() +// CHECK-NEXT: %[[CALLEE_RESULT:.*]] = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]], %[[FIRST_ALLOC]], %[[SECOND_ALLOC]]) +// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>, memref<2xf32>, memref<5xf32>) -> i1 +// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[FIRST_ALLOC]], %[[CALLEE_RESULT]], %[[SECOND_ALLOC]]) +// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32} +// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32} +// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32} +// CHECK-NEXT: linalg.copy(%[[FIRST_ELEM]], %[[RESULT0]]) +// CHECK-NEXT: linalg.copy(%[[THIRD_ELEM]], %[[RESULT1]]) +// CHECK-NEXT: return %[[SECOND_ELEM]] + +// ----- + +// CHECK-LABEL: func @test +func @test(%arg0: tuple<i1, f32>, %arg1: tensor<10xf32>, %arg2: tuple<i1, tensor<5xf32>>) -> (tuple<i1, tensor<5xf32>>, tensor<10xf32>, tuple<i1, f32>){ + return %arg2, %arg1, %arg0 : tuple<i1, tensor<5xf32>>, tensor<10xf32>, tuple<i1, f32> +} +// CHECK-SAME: %[[ARG0:.*]]: i1,
%[[ARG1:.*]]: f32, %[[ARG2:.*]]: memref<10xf32>, %[[ARG3:.*]]: i1, %[[ARG4:.*]]: memref<5xf32>, %[[RESULT0:.*]]: memref<5xf32>, %[[RESULT1:.*]]: memref<10xf32> +// CHECK-SAME: (i1, i1, f32) +// CHECK-NEXT: %[[FIRST_TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]]) +// CHECK-NEXT: %[[SECOND_TUPLE:.*]] = "test.make_tuple"(%[[ARG3]], %[[ARG4]]) +// CHECK-NEXT: %[[SECOND_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 0 : i32} +// CHECK-NEXT: %[[SECOND_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 1 : i32} +// CHECK-NEXT: %[[FIRST_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 0 : i32} +// CHECK-NEXT: %[[FIRST_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 1 : i32} +// CHECK-NEXT: linalg.copy(%[[SECOND_TUPLE_SECOND_ELEM]], %[[RESULT0]]) +// CHECK-NEXT: linalg.copy(%[[ARG2]], %[[RESULT1]]) +// CHECK-NEXT: return %[[SECOND_TUPLE_FIRST_ELEM]], %[[FIRST_TUPLE_FIRST_ELEM]], %[[FIRST_TUPLE_SECOND_ELEM]] diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1473,4 +1473,27 @@ }]; } +def GetTupleElementOp: TEST_Op<"get_tuple_element"> { + let description = [{ + Test op that returns the i-th element of the tuple. + }]; + + let arguments = (ins + TupleOf<[AnyType]>, + I32Attr:$index + ); + let results = (outs AnyType); +} + +def MakeTupleOp: TEST_Op<"make_tuple"> { + let description = [{ + Test op that creates a tuple value from a list of values. 
+ }]; + + let arguments = (ins + Variadic<AnyType>:$inputs + ); + let results = (outs TupleOf<[AnyType]>); +} + #endif // TEST_OPS diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp --- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp +++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp @@ -11,6 +11,8 @@ // //===----------------------------------------------------------------------===// +#include "TestDialect.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/IR/Function.h" #include "mlir/IR/Operation.h" @@ -109,10 +111,11 @@ void populateTensorLinalgToBufferLinalgConversionPattern( MLIRContext *context, BufferAssignmentPlacer *placer, - TypeConverter *converter, OwningRewritePatternList *patterns) { + BufferAssignmentTypeConverter *converter, + OwningRewritePatternList *patterns) { populateWithBufferAssignmentOpConversionPatterns< - mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp, - allowMemrefFunctionResults>(context, placer, converter, patterns); + mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(context, placer, + converter, patterns); patterns->insert<GenericOpConverter>(context, placer, converter); } @@ -123,6 +126,8 @@ // Mark all Standard operations legal. target.addLegalDialect<StandardOpsDialect>(); + target.addLegalOp<MakeTupleOp>(); + target.addLegalOp<GetTupleElementOp>(); // Mark all Linalg operations illegal as long as they work on tensors. auto isLegalOperation = [&](Operation *op) { @@ -145,6 +150,42 @@ converter.isLegal(&funcOp.getBody()); }); + converter.setResultConversionKind<MemRefType>( + allowMemrefFunctionResults + ?
BufferAssignmentTypeConverter::KeepAsFunctionResult + : BufferAssignmentTypeConverter::AppendToArgumentsList); + + converter.addConversion([&](TupleType tupleType, + SmallVectorImpl<Type> &types) { + return success(llvm::all_of(tupleType.getTypes(), [&](Type type) { + return succeeded(converter.convertType(type, types)); + })); + }); + + converter.addArgumentMaterialization( + [](OpBuilder &builder, TupleType resultType, ValueRange inputs, + Location loc) -> Optional<mlir::Value> { + if (inputs.size() == 1) + return llvm::None; + TypeRange inputTypes = inputs.getTypes(); + SmallVector<Type, 2> types(inputTypes.begin(), inputTypes.end()); + TupleType tuple = TupleType::get(types, builder.getContext()); + mlir::Value value = builder.create<MakeTupleOp>(loc, tuple, inputs); + return value; + }); + + converter.addDecomposeValueConversion( + [](OpBuilder &builder, Location loc, TupleType resultType, + mlir::Value value, SmallVectorImpl<mlir::Value> &values) { + for (unsigned i = 0; i < resultType.size(); ++i) { + mlir::Value res = builder.create<GetTupleElementOp>( + loc, resultType.getType(i), value, + builder.getI32IntegerAttr(i)); + values.push_back(res); + } + return success(); + }); + // Walk over all the functions to apply buffer assignment. this->getOperation().walk([&](FuncOp function) -> WalkResult { OwningRewritePatternList patterns;