diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.h @@ -0,0 +1,90 @@ +//===- DecomposeCallGraphTypes.h - CG type decompositions -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Conversion patterns for decomposing types along call graph edges. That is, +// decomposing types for calls, returns, and function args. +// +// TODO: Make this handle dialect-defined functions, calls, and returns. +// Currently, the generic interfaces aren't sophisticated enough for the +// types of mutations that we are doing here. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_STANDARDOPS_TRANSFORMS_DECOMPOSECALLGRAPHTYPES_H +#define MLIR_DIALECT_STANDARDOPS_TRANSFORMS_DECOMPOSECALLGRAPHTYPES_H + +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { + +/// This class provides a hook that expands one Value into multiple Value's, +/// with a TypeConverter-inspired callback registration mechanism. +/// +/// For folks that are familiar with the dialect conversion framework / +/// TypeConverter, this is effectively the inverse of a source/argument +/// materialization. A target materialization is not what we want here because +/// it always produces a single Value, but in this case the whole point is to +/// decompose a Value into multiple Value's. +/// +/// The reason we need this inverse is easily understood by looking at what we +/// need to do for decomposing types for a return op. 
When converting a return +/// op, the dialect conversion framework will give the list of converted +/// operands, and will ensure that each converted operand, even if it expanded +/// into multiple types, is materialized as a single result. We then need to +/// undo that materialization to a single result, which we do with the +/// decomposeValue hooks registered on this object. +/// +/// TODO: Eventually, the type conversion infra should have this hook built-in. +/// See +/// https://llvm.discourse.group/t/extending-type-conversion-infrastructure/779/2 +class ValueDecomposer { +public: + /// This method tries to decompose a value of a certain type using provided + /// decompose callback functions. If it is unable to do so, the original value + /// is returned. + void decomposeValue(OpBuilder &, Location, Type, Value, + SmallVectorImpl &); + + /// This method registers a callback function that will be called to decompose + /// a value of a certain type into 0, 1, or multiple values. + template ::template arg_t<2>> + void addDecomposeValueConversion(FnT &&callback) { + decomposeValueConversions.emplace_back( + wrapDecomposeValueConversionCallback(std::forward(callback))); + } + +private: + using DecomposeValueConversionCallFn = std::function( + OpBuilder &, Location, Type, Value, SmallVectorImpl &)>; + + /// Generate a wrapper for the given decompose value conversion callback. + template + DecomposeValueConversionCallFn + wrapDecomposeValueConversionCallback(FnT &&callback) { + return [callback = std::forward(callback)]( + OpBuilder &builder, Location loc, Type type, Value value, + SmallVectorImpl &newValues) -> Optional { + if (T derivedType = type.dyn_cast()) + return callback(builder, loc, derivedType, value, newValues); + return llvm::None; + }; + } + + SmallVector decomposeValueConversions; +}; + +/// Populates the patterns needed to drive the conversion process for +/// decomposing call graph types with the given `ValueDecomposer`. 
+void populateDecomposeCallGraphTypesPatterns( + MLIRContext *context, TypeConverter &typeConverter, + ValueDecomposer &decomposer, OwningRewritePatternList &patterns); + +} // end namespace mlir + +#endif // MLIR_DIALECT_STANDARDOPS_TRANSFORMS_DECOMPOSECALLGRAPHTYPES_H diff --git a/mlir/include/mlir/Transforms/Bufferize.h b/mlir/include/mlir/Transforms/Bufferize.h --- a/mlir/include/mlir/Transforms/Bufferize.h +++ b/mlir/include/mlir/Transforms/Bufferize.h @@ -15,13 +15,8 @@ // // Bufferization conversion patterns should generally use the ordinary // conversion pattern classes (e.g. OpConversionPattern). A TypeConverter -// (accessible with getTypeConverter()) available on such patterns is sufficient -// for most cases (if needed at all). -// -// But some patterns require access to the extra functions on -// BufferizeTypeConverter that don't exist on the base TypeConverter class. For -// those cases, BufferizeConversionPattern and its related classes should be -// used, which provide access to a BufferizeTypeConverter directly. +// (accessible with getTypeConverter()) is available if needed for converting +// types. // //===----------------------------------------------------------------------===// @@ -39,79 +34,11 @@ namespace mlir { -/// A helper type converter class for using inside Buffer Assignment operation -/// conversion patterns. The default constructor keeps all the types intact -/// except for the ranked-tensor types which is converted to memref types. +/// A helper type converter class that automatically populates the relevant +/// materializations and type conversions for bufferization. class BufferizeTypeConverter : public TypeConverter { public: BufferizeTypeConverter(); - - /// This method tries to decompose a value of a certain type using provided - /// decompose callback functions. If it is unable to do so, the original value - /// is returned.
- void tryDecomposeValue(OpBuilder &, Location, Type, Value, - SmallVectorImpl &); - - /// This method tries to decompose a type using provided decompose callback - /// functions. If it is unable to do so, the original type is returned. - void tryDecomposeType(Type, SmallVectorImpl &); - - /// This method registers a callback function that will be called to decompose - /// a value of a certain type into several values. - template ::template arg_t<2>> - void addDecomposeValueConversion(FnT &&callback) { - decomposeValueConversions.emplace_back( - wrapDecomposeValueConversionCallback(std::forward(callback))); - } - - /// This method registers a callback function that will be called to decompose - /// a type into several types. - template ::template arg_t<0>> - void addDecomposeTypeConversion(FnT &&callback) { - auto wrapper = - wrapDecomposeTypeConversionCallback(std::forward(callback)); - decomposeTypeConversions.emplace_back(wrapper); - addConversion(std::forward(callback)); - } - -private: - using DecomposeValueConversionCallFn = std::function( - OpBuilder &, Location, Type, Value, SmallVectorImpl &)>; - - using DecomposeTypeConversionCallFn = - std::function(Type, SmallVectorImpl &)>; - - /// Generate a wrapper for the given decompose value conversion callback. - template - DecomposeValueConversionCallFn - wrapDecomposeValueConversionCallback(FnT &&callback) { - return [callback = std::forward(callback)]( - OpBuilder &builder, Location loc, Type type, Value value, - SmallVectorImpl &newValues) -> Optional { - if (T derivedType = type.dyn_cast()) - return callback(builder, loc, derivedType, value, newValues); - return llvm::None; - }; - } - - /// Generate a wrapper for the given decompose type conversion callback.
- template - DecomposeTypeConversionCallFn - wrapDecomposeTypeConversionCallback(FnT &&callback) { - return [callback = std::forward(callback)]( - Type type, - SmallVectorImpl &results) -> Optional { - T derivedType = type.dyn_cast(); - if (!derivedType) - return llvm::None; - return callback(derivedType, results); - }; - } - - SmallVector decomposeValueConversions; - SmallVector decomposeTypeConversions; }; /// Marks ops used by bufferization for type conversion materializations as @@ -132,104 +59,6 @@ MLIRContext *context, BufferizeTypeConverter &typeConverter, OwningRewritePatternList &patterns); -/// Helper conversion pattern that encapsulates a BufferizeTypeConverter -/// instance. -template -class BufferizeOpConversionPattern : public OpConversionPattern { -public: - explicit BufferizeOpConversionPattern(MLIRContext *context, - BufferizeTypeConverter &converter, - PatternBenefit benefit = 1) - : OpConversionPattern(context, benefit), converter(converter) {} - -protected: - BufferizeTypeConverter &converter; -}; - -/// Helper conversion pattern that encapsulates a BufferizeTypeConverter -/// instance and that operates on Operation* to be compatible with OpInterfaces. -/// This allows avoiding to instantiate N patterns for ops that can be subsumed -/// by a single op interface (e.g. Linalg named ops). -class BufferizeConversionPattern : public ConversionPattern { -public: - explicit BufferizeConversionPattern(MLIRContext *context, - BufferizeTypeConverter &converter, - PatternBenefit benefit = 1) - : ConversionPattern(benefit, converter, MatchAnyOpTypeTag()), - converter(converter) {} - -protected: - BufferizeTypeConverter &converter; -}; - -/// Converts the signature of the function using BufferizeTypeConverter. -/// Each result type of the function is kept as a function result or appended to -/// the function arguments list based on ResultConversionKind for the converted -/// result type.
-class BufferizeFuncOpConverter : public BufferizeOpConversionPattern { -public: - using BufferizeOpConversionPattern::BufferizeOpConversionPattern; - - /// Performs the actual signature rewriting step. - LogicalResult matchAndRewrite(mlir::FuncOp, ArrayRef, - ConversionPatternRewriter &) const override; -}; - -/// Rewrites the `ReturnOp` to conform with the changed function signature. -/// Operands that correspond to return values and their types have been set to -/// AppendToArgumentsList are dropped. In their place, a corresponding copy -/// operation from the operand to the target function argument is inserted. -template -class BufferizeReturnOpConverter - : public BufferizeOpConversionPattern { -public: - using BufferizeOpConversionPattern< - ReturnOpSourceTy>::BufferizeOpConversionPattern; - - /// Performs the actual return-op conversion step. - LogicalResult - matchAndRewrite(ReturnOpSourceTy returnOp, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - SmallVector newOperands; - for (auto operand : operands) - this->converter.tryDecomposeValue( - rewriter, returnOp.getLoc(), operand.getType(), operand, newOperands); - rewriter.replaceOpWithNewOp(returnOp, newOperands); - return success(); - } -}; - -/// Rewrites the `CallOp` to match its operands and results with the signature -/// of the callee after rewriting the callee with -/// BufferizeFuncOpConverter. -class BufferizeCallOpConverter : public BufferizeOpConversionPattern { -public: - using BufferizeOpConversionPattern::BufferizeOpConversionPattern; - - /// Performs the actual rewriting step. - LogicalResult matchAndRewrite(CallOp, ArrayRef, - ConversionPatternRewriter &) const override; -}; - -/// Populates `patterns` with the conversion patterns of buffer -/// assignment.
-template -static void -populateWithBufferizeOpConversionPatterns(MLIRContext *context, - BufferizeTypeConverter &converter, - OwningRewritePatternList &patterns) { - // clang-format off - patterns.insert< - BufferizeCallOpConverter, - BufferizeFuncOpConverter, - BufferizeReturnOpConverter - - >(context, converter); - // clang-format on -} - /// A simple analysis that detects allocation operations. class BufferPlacementAllocs { public: diff --git a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRStandardOpsTransforms Bufferize.cpp + DecomposeCallGraphTypes.cpp ExpandOps.cpp ExpandTanh.cpp FuncBufferize.cpp diff --git a/mlir/lib/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.cpp @@ -0,0 +1,191 @@ +//===- DecomposeCallGraphTypes.cpp - CG type decomposition ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Function.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// ValueDecomposer +//===----------------------------------------------------------------------===// + +void ValueDecomposer::decomposeValue(OpBuilder &builder, Location loc, + Type type, Value value, + SmallVectorImpl &results) { + for (auto &conversion : decomposeValueConversions) + if (conversion(builder, loc, type, value, results)) + return; + results.push_back(value); +} + +//===----------------------------------------------------------------------===// +// DecomposeCallGraphTypesOpConversionPattern +//===----------------------------------------------------------------------===// + +namespace { +// Base OpConversionPattern class to make a ValueDecomposer available to +// inherited patterns. +template +class DecomposeCallGraphTypesOpConversionPattern + : public OpConversionPattern { +public: + DecomposeCallGraphTypesOpConversionPattern(TypeConverter &typeConverter, + MLIRContext *context, + ValueDecomposer &decomposer, + PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit), + decomposer(decomposer) {} + +protected: + ValueDecomposer &decomposer; +}; +} // namespace + +//===----------------------------------------------------------------------===// +// DecomposeCallGraphTypesForFuncArgs +//===----------------------------------------------------------------------===// + +namespace { +/// Expand function arguments according to the provided TypeConverter and +/// ValueDecomposer. 
+struct DecomposeCallGraphTypesForFuncArgs + : public DecomposeCallGraphTypesOpConversionPattern { + using DecomposeCallGraphTypesOpConversionPattern:: + DecomposeCallGraphTypesOpConversionPattern; + + LogicalResult + matchAndRewrite(FuncOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto functionType = op.getType(); + + // Convert function arguments using the provided TypeConverter. + TypeConverter::SignatureConversion conversion(functionType.getNumInputs()); + for (auto argType : llvm::enumerate(functionType.getInputs())) { + SmallVector decomposedTypes; + getTypeConverter()->convertType(argType.value(), decomposedTypes); + if (!decomposedTypes.empty()) + conversion.addInputs(argType.index(), decomposedTypes); + } + + if (failed(rewriter.convertRegionTypes(&op.getBody(), *getTypeConverter(), + &conversion))) + return failure(); + + // Update the signature of the function. + SmallVector newResultTypes; + getTypeConverter()->convertTypes(functionType.getResults(), newResultTypes); + rewriter.updateRootInPlace(op, [&] { + op.setType(rewriter.getFunctionType(conversion.getConvertedTypes(), + newResultTypes)); + }); + return success(); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// DecomposeCallGraphTypesForReturnOp +//===----------------------------------------------------------------------===// + +namespace { +/// Expand return operands according to the provided TypeConverter and +/// ValueDecomposer. 
+struct DecomposeCallGraphTypesForReturnOp + : public DecomposeCallGraphTypesOpConversionPattern { + using DecomposeCallGraphTypesOpConversionPattern:: + DecomposeCallGraphTypesOpConversionPattern; + LogicalResult + matchAndRewrite(ReturnOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + SmallVector newOperands; + for (Value operand : operands) + decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(), + operand, newOperands); + rewriter.replaceOpWithNewOp(op, newOperands); + return success(); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// DecomposeCallGraphTypesForCallOp +//===----------------------------------------------------------------------===// + +namespace { +/// Expand call op operands and results according to the provided TypeConverter +/// and ValueDecomposer. +struct DecomposeCallGraphTypesForCallOp + : public DecomposeCallGraphTypesOpConversionPattern { + using DecomposeCallGraphTypesOpConversionPattern:: + DecomposeCallGraphTypesOpConversionPattern; + + LogicalResult + matchAndRewrite(CallOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + + // Create the operands list of the new `CallOp`. + SmallVector newOperands; + for (Value operand : operands) + decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(), + operand, newOperands); + + // Create the new result types for the new `CallOp` and track the indices in + // the new call op's results that correspond to the old call op's results. + // + // expandedResultIndices[i] = "list of new result indices that old result i + // expanded to". 
+ SmallVector newResultTypes; + SmallVector, 4> expandedResultIndices; + for (Type resultType : op.getResultTypes()) { + unsigned oldSize = newResultTypes.size(); + getTypeConverter()->convertType(resultType, newResultTypes); + auto &resultMapping = expandedResultIndices.emplace_back(); + for (unsigned i = oldSize, e = newResultTypes.size(); i < e; i++) + resultMapping.push_back(i); + } + + CallOp newCallOp = rewriter.create(op.getLoc(), op.getCallee(), + newResultTypes, newOperands); + + // Build a replacement value for each result to replace its uses. If a + // result has multiple mapping values, it needs to be materialized as a + // single value. + SmallVector replacedValues; + replacedValues.reserve(op.getNumResults()); + for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) { + auto decomposedValues = llvm::to_vector<6>( + llvm::map_range(expandedResultIndices[i], + [&](unsigned i) { return newCallOp.getResult(i); })); + if (decomposedValues.empty()) { + // No replacement is required. + replacedValues.push_back(nullptr); + } else if (decomposedValues.size() == 1) { + replacedValues.push_back(decomposedValues.front()); + } else { + // Materialize a single Value to replace the original Value. 
+ Value materialized = getTypeConverter()->materializeArgumentConversion( + rewriter, op.getLoc(), op.getType(i), decomposedValues); + replacedValues.push_back(materialized); + } + } + rewriter.replaceOp(op, replacedValues); + return success(); + } +}; +} // namespace + +void mlir::populateDecomposeCallGraphTypesPatterns( + MLIRContext *context, TypeConverter &typeConverter, + ValueDecomposer &decomposer, OwningRewritePatternList &patterns) { + patterns.insert(typeConverter, context, + decomposer); +} diff --git a/mlir/lib/Transforms/Bufferize.cpp b/mlir/lib/Transforms/Bufferize.cpp --- a/mlir/lib/Transforms/Bufferize.cpp +++ b/mlir/lib/Transforms/Bufferize.cpp @@ -41,28 +41,6 @@ }); } -/// This method tries to decompose a value of a certain type using provided -/// decompose callback functions. If it is unable to do so, the original value -/// is returned. -void BufferizeTypeConverter::tryDecomposeValue( - OpBuilder &builder, Location loc, Type type, Value value, - SmallVectorImpl &results) { - for (auto &conversion : decomposeValueConversions) - if (conversion(builder, loc, type, value, results)) - return; - results.push_back(value); -} - -/// This method tries to decompose a type using provided decompose callback -/// functions. If it is unable to do so, the original type is returned. -void BufferizeTypeConverter::tryDecomposeType(Type type, - SmallVectorImpl &types) { - for (auto &conversion : decomposeTypeConversions) - if (conversion(type, types)) - return; - types.push_back(type); -} - void mlir::populateBufferizeMaterializationLegality(ConversionTarget &target) { target.addLegalOp(); } @@ -105,113 +83,3 @@ patterns.insert( typeConverter, context); } - -//===----------------------------------------------------------------------===// -// BufferizeFuncOpConverter -//===----------------------------------------------------------------------===// - -/// Performs the actual function signature rewriting step. 
-LogicalResult BufferizeFuncOpConverter::matchAndRewrite( - mlir::FuncOp funcOp, ArrayRef operands, - ConversionPatternRewriter &rewriter) const { - auto funcType = funcOp.getType(); - - // Convert function arguments using the provided TypeConverter. - TypeConverter::SignatureConversion conversion(funcType.getNumInputs()); - for (auto argType : llvm::enumerate(funcType.getInputs())) { - SmallVector decomposedTypes, convertedTypes; - converter.tryDecomposeType(argType.value(), decomposedTypes); - converter.convertTypes(decomposedTypes, convertedTypes); - conversion.addInputs(argType.index(), convertedTypes); - } - - // Convert the result types of the function. - SmallVector newResultTypes; - newResultTypes.reserve(funcOp.getNumResults()); - for (Type resultType : funcType.getResults()) { - SmallVector originTypes; - converter.tryDecomposeType(resultType, originTypes); - for (auto origin : originTypes) - newResultTypes.push_back(converter.convertType(origin)); - } - - if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), converter, - &conversion))) - return failure(); - - // Update the signature of the function. - rewriter.updateRootInPlace(funcOp, [&] { - funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(), - newResultTypes)); - }); - return success(); -} - -//===----------------------------------------------------------------------===// -// BufferizeCallOpConverter -//===----------------------------------------------------------------------===// - -/// Performs the actual rewriting step. -LogicalResult BufferizeCallOpConverter::matchAndRewrite( - CallOp callOp, ArrayRef operands, - ConversionPatternRewriter &rewriter) const { - - Location loc = callOp.getLoc(); - SmallVector newOperands; - - // TODO: if the CallOp references a FuncOp that only has a declaration (e.g. - // to an externally defined symbol like an external library calls), only - // convert if some special attribute is set.
- // This will allow more control of interop across ABI boundaries. - - // Create the operands list of the new `CallOp`. It unpacks the decomposable - // values if a decompose callback function has been provided by the user. - for (auto operand : operands) - converter.tryDecomposeValue(rewriter, loc, operand.getType(), operand, - newOperands); - - // Create the new result types for the new `CallOp` and track the indices in - // the new call op's results that correspond to the old call op's results. - SmallVector newResultTypes; - SmallVector, 4> expandedResultIndices; - expandedResultIndices.resize(callOp.getNumResults()); - for (auto result : llvm::enumerate(callOp.getResults())) { - SmallVector originTypes; - converter.tryDecomposeType(result.value().getType(), originTypes); - auto &resultMapping = expandedResultIndices[result.index()]; - for (Type origin : originTypes) { - Type converted = converter.convertType(origin); - newResultTypes.push_back(converted); - // The result value is not yet available. Its index is kept and it is - // replaced with the actual value of the new `CallOp` later. - resultMapping.push_back(newResultTypes.size() - 1); - } - } - - CallOp newCallOp = rewriter.create(loc, callOp.getCallee(), - newResultTypes, newOperands); - - // Build a replacing value for each result to replace its uses. If a result - // has multiple mapping values, it needs to be packed to a single value. - SmallVector replacedValues; - replacedValues.reserve(callOp.getNumResults()); - for (unsigned i = 0, e = callOp.getNumResults(); i < e; ++i) { - auto valuesToPack = llvm::to_vector<6>( - llvm::map_range(expandedResultIndices[i], - [&](int i) { return newCallOp.getResult(i); })); - if (valuesToPack.empty()) { - // No replacement is required. - replacedValues.push_back(nullptr); - } else if (valuesToPack.size() == 1) { - replacedValues.push_back(valuesToPack.front()); - } else { - // Values need to be packed using callback function.
The same callback - // that is used for materializeArgumentConversion is used for packing. - Value packed = converter.materializeArgumentConversion( - rewriter, loc, callOp.getType(i), valuesToPack); - replacedValues.push_back(packed); - } - } - rewriter.replaceOp(callOp, replacedValues); - return success(); -} diff --git a/mlir/test/Transforms/decompose-call-graph-types.mlir b/mlir/test/Transforms/decompose-call-graph-types.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Transforms/decompose-call-graph-types.mlir @@ -0,0 +1,105 @@ +// RUN: mlir-opt %s -split-input-file -test-decompose-call-graph-types | FileCheck %s + +// Test case: Most basic case of a 1:N decomposition, an identity function. + +// CHECK-LABEL: func @identity( +// CHECK-SAME: %[[ARG0:.*]]: i1, +// CHECK-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) { +// CHECK: %[[ARG_MATERIALIZED:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> tuple +// CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) {index = 0 : i32} : (tuple) -> i1 +// CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) {index = 1 : i32} : (tuple) -> i32 +// CHECK: return %[[RET0]], %[[RET1]] : i1, i32 +func @identity(%arg0: tuple) -> tuple { + return %arg0 : tuple +} + +// ----- + +// Test case: Ensure no materializations in the case of 1:1 decomposition. + +// CHECK-LABEL: func @identity_1_to_1_no_materializations( +// CHECK-SAME: %[[ARG0:.*]]: i1) -> i1 { +// CHECK: return %[[ARG0]] : i1 +func @identity_1_to_1_no_materializations(%arg0: tuple) -> tuple { + return %arg0 : tuple +} + +// ----- + +// Test case: Type that needs to be recursively decomposed. + +// CHECK-LABEL: func @recursive_decomposition( +// CHECK-SAME: %[[ARG0:.*]]: i1) -> i1 { +// CHECK: return %[[ARG0]] : i1 +func @recursive_decomposition(%arg0: tuple>>) -> tuple>> { + return %arg0 : tuple>> +} + +// ----- + +// Test case: Check decomposition of calls.
+ +// CHECK-LABEL: func @callee(i1, i32) -> (i1, i32) +func @callee(tuple) -> tuple + +// CHECK-LABEL: func @caller( +// CHECK-SAME: %[[ARG0:.*]]: i1, +// CHECK-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) { +// CHECK: %[[ARG_MATERIALIZED:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> tuple +// CHECK: %[[CALL_ARG0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) {index = 0 : i32} : (tuple) -> i1 +// CHECK: %[[CALL_ARG1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) {index = 1 : i32} : (tuple) -> i32 +// CHECK: %[[DECOMPOSED:.*]]:2 = call @callee(%[[CALL_ARG0]], %[[CALL_ARG1]]) : (i1, i32) -> (i1, i32) +// CHECK: %[[CALL_RESULT_RECOMPOSED:.*]] = "test.make_tuple"(%[[DECOMPOSED]]#0, %[[DECOMPOSED]]#1) : (i1, i32) -> tuple +// CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) {index = 0 : i32} : (tuple) -> i1 +// CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) {index = 1 : i32} : (tuple) -> i32 +// CHECK: return %[[RET0]], %[[RET1]] : i1, i32 +func @caller(%arg0: tuple) -> tuple { + %0 = call @callee(%arg0) : (tuple) -> tuple + return %0 : tuple +} + +// ----- + +// Test case: Type that decomposes to nothing (that is, a 1:0 decomposition). + +// CHECK-LABEL: func @callee() +func @callee(tuple<>) -> tuple<> +// CHECK-LABEL: func @caller() { +// CHECK: call @callee() : () -> () +// CHECK: return +func @caller(%arg0: tuple<>) -> tuple<> { + %0 = call @callee(%arg0) : (tuple<>) -> (tuple<>) + return %0 : tuple<> +} + +// ----- + +// Test case: Ensure decompositions are inserted properly around results of +// unconverted ops.
+ +// CHECK-LABEL: func @unconverted_op_result() -> (i1, i32) { +// CHECK: %[[UNCONVERTED_VALUE:.*]] = "test.source"() : () -> tuple +// CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 0 : i32} : (tuple) -> i1 +// CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 1 : i32} : (tuple) -> i32 +// CHECK: return %[[RET0]], %[[RET1]] : i1, i32 +func @unconverted_op_result() -> tuple { + %0 = "test.source"() : () -> (tuple) + return %0 : tuple +} + +// ----- + +// Test case: Check mixed decomposed and non-decomposed args. + +// CHECK-LABEL: func @callee(i32, i64) -> (f32, f64) +func @callee(tuple<>, i32, tuple<>, i64) -> (tuple<>, f32, tuple<>, f64) + +// CHECK-LABEL: func @caller( +// CHECK-SAME: %[[ARG0:.*]]: i32, +// CHECK-SAME: %[[ARG1:.*]]: i64) -> (f32, f64) { +// CHECK: %[[RESULTS:.*]]:2 = call @callee(%[[ARG0]], %[[ARG1]]) : (i32, i64) -> (f32, f64) +// CHECK: return %[[RESULTS]]#0, %[[RESULTS]]#1 : f32, f64 +func @caller(%arg0: tuple<>, %arg1: i32, %arg2: tuple<>, %arg3: i64) -> (tuple<>, f32, tuple<>, f64) { + %0, %1, %2, %3 = call @callee(%arg0, %arg1, %arg2, %arg3) : (tuple<>, i32, tuple<>, i64) -> (tuple<>, f32, tuple<>, f64) + return %0, %1, %2, %3 : tuple<>, f32, tuple<>, f64 +} diff --git a/mlir/test/Transforms/finalizing-bufferize.mlir b/mlir/test/Transforms/finalizing-bufferize.mlir deleted file mode 100644 --- a/mlir/test/Transforms/finalizing-bufferize.mlir +++ /dev/null @@ -1,180 +0,0 @@ -// RUN: mlir-opt -test-finalizing-bufferize -split-input-file %s | FileCheck %s - -// CHECK-LABEL: func @void_function_signature_conversion -func @void_function_signature_conversion(%arg0: tensor<4x8xf32>) { - return -} -// CHECK: ({{.*}}: memref<4x8xf32>) - -// ----- - -// CHECK-LABEL: func @complex_signature_conversion -func @complex_signature_conversion( - %arg0: tensor<5xf32>, - %arg1: memref<10xf32>, - %arg2: i1, - %arg3: f16) -> ( - i1, - tensor<5xf32>, - memref<10xf32>, - memref<15xf32>, - f16) { -
%0 = alloc() : memref<15xf32> - %1 = test.tensor_based in(%arg0 : tensor<5xf32>) -> tensor<5xf32> - return %arg2, %1, %arg1, %0, %arg3 : - i1, tensor<5xf32>, memref<10xf32>, memref<15xf32>, f16 -} -// CHECK: (%[[ARG0:.*]]: memref<5xf32>, %[[ARG1:.*]]: memref<10xf32>, -// CHECK-SAME: %[[ARG2:.*]]: i1, %[[ARG3:.*]]: f16) -// CHECK-SAME: (i1, memref<5xf32>, memref<10xf32>, memref<15xf32>, f16) -// CHECK: %[[FIRST_ALLOC:.*]] = alloc() -// CHECK: %[[TENSOR_ALLOC:.*]] = alloc() -// CHECK: return %[[ARG2]], %[[TENSOR_ALLOC]], %[[ARG1]], %[[FIRST_ALLOC]], -// CHECK-SAME: %[[ARG3]] - -// ----- - -// CHECK-LABEL: func @no_signature_conversion_is_needed -func @no_signature_conversion_is_needed(%arg0: memref<4x8xf32>) { - return -} -// CHECK: ({{.*}}: memref<4x8xf32>) - -// ----- - -// CHECK-LABEL: func @no_signature_conversion_is_needed -func @no_signature_conversion_is_needed(%arg0: i1, %arg1: f16) -> (i1, f16){ - return %arg0, %arg1 : i1, f16 -} -// CHECK: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: f16) -> (i1, f16) -// CHECK: return %[[ARG0]], %[[ARG1]] - -// ----- - -// CHECK-LABEL: func @simple_signature_conversion -func @simple_signature_conversion(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> { - return %arg0 : tensor<4x8xf32> -} -// CHECK: (%[[ARG0:.*]]: [[TYPE:.*]]<[[RANK:.*]]>) -> [[TYPE]]<[[RANK]]> -// CHECK-NEXT: return %[[ARG0]] - -// ----- - -// CHECK-LABEL: func @func_with_unranked_arg_and_result -func @func_with_unranked_arg_and_result(%arg0: tensor<*xf32>) -> tensor<*xf32> { - return %arg0 : tensor<*xf32> -} -// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>) -> memref<*xf32> -// CHECK-NEXT: return [[ARG]] : memref<*xf32> - -// ----- - -// CHECK-LABEL: func @func_and_block_signature_conversion -func @func_and_block_signature_conversion(%arg0 : tensor<2xf32>, %cond : i1, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32>{ - cond_br %cond, ^bb1, ^bb2 - ^bb1: - br ^exit(%arg0 : tensor<2xf32>) - ^bb2: - br ^exit(%arg0 : tensor<2xf32>) - ^exit(%arg2: tensor<2xf32>): - return %arg1 : 
tensor<4x4xf32> -} -// CHECK: (%[[ARG0:.*]]: [[ARG0_TYPE:.*]], %[[COND:.*]]: i1, %[[ARG1:.*]]: [[ARG1_TYPE:.*]]) -> [[RESULT_TYPE:.*]] { -// CHECK: br ^[[EXIT_BLOCK:.*]](%[[ARG0]] : [[ARG0_TYPE]]) -// CHECK: br ^[[EXIT_BLOCK]](%[[ARG0]] : [[ARG0_TYPE]]) -// CHECK: ^[[EXIT_BLOCK]](%{{.*}}: [[ARG0_TYPE]]) -// CHECK-NEXT: return %[[ARG1]] : [[RESULT_TYPE]] - -// ----- - -// CHECK-LABEL: func @callee -func @callee(%arg1: tensor<5xf32>) -> (tensor<5xf32>, memref<2xf32>) { - %buff = alloc() : memref<2xf32> - return %arg1, %buff : tensor<5xf32>, memref<2xf32> -} -// CHECK: (%[[CALLEE_ARG:.*]]: memref<5xf32>) -> (memref<5xf32>, memref<2xf32>) -// CHECK: %[[ALLOC:.*]] = alloc() -// CHECK: return %[[CALLEE_ARG]], %[[ALLOC]] - -// CHECK-LABEL: func @caller -func @caller(%arg0: tensor<5xf32>) -> tensor<5xf32> { - %x:2 = call @callee(%arg0) : (tensor<5xf32>) -> (tensor<5xf32>, memref<2xf32>) - %y:2 = call @callee(%x#0) : (tensor<5xf32>) -> (tensor<5xf32>, memref<2xf32>) - return %y#0 : tensor<5xf32> -} -// CHECK: (%[[CALLER_ARG:.*]]: memref<5xf32>) -> memref<5xf32> -// CHECK: %[[X:.*]]:2 = call @callee(%[[CALLER_ARG]]) -// CHECK: %[[Y:.*]]:2 = call @callee(%[[X]]#0) -// CHECK: return %[[Y]]#0 - -// ----- - -// Test case: Testing BufferizeCallOpConverter to see if it matches with the -// signature of the new signature of the callee function when there are tuple -// typed args and results. BufferizeTypeConverter is set to flatten tuple typed -// arguments. The tuple typed values should be decomposed and composed using -// get_tuple_element and make_tuple operations of test dialect. Tensor types are -// converted to Memref. Memref typed function results remain as function -// results. 
- -// CHECK-LABEL: func @callee -func @callee(%arg0: tuple,i1, tensor<5xf32>>) -> (tuple,i1, tensor<5xf32>>){ - return %arg0 : tuple,i1, tensor<5xf32>> -} -// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>) -// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>) -// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) -// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32} -// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32} -// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32} -// CHECK-NEXT: return %[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]] - -// CHECK-LABEL: func @caller -func @caller(%arg0: tuple,i1, tensor<5xf32>>) -> tuple,i1, tensor<5xf32>>{ - %x0 = call @callee(%arg0) : (tuple,i1, tensor<5xf32>>) -> (tuple,i1, tensor<5xf32>>) - %y0 = call @callee(%x0) : (tuple,i1, tensor<5xf32>>) -> (tuple,i1, tensor<5xf32>>) - return %y0 : tuple,i1, tensor<5xf32>> -} -// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>) -// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>) -// CHECK-NEXT: %[[ARG_TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) -// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 0 : i32} -// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 1 : i32} -// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 2 : i32} -// CHECK-NEXT: %[[CALLEE_RESULTS:.*]]:3 = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]]) -// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>) -> (memref<2xf32>, i1, memref<5xf32>) -// CHECK-NEXT: %[[RESULT_TUPLE:.*]] = "test.make_tuple"(%[[CALLEE_RESULTS]]#0, %[[CALLEE_RESULTS]]#1, %[[CALLEE_RESULTS]]#2) -// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[RESULT_TUPLE]]) {index = 0 : i32} 
-// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[RESULT_TUPLE]]) {index = 1 : i32} -// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[RESULT_TUPLE]]) {index = 2 : i32} -// CHECK-NEXT: %[[CALLEE_RESULTS:.*]]:3 = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]]) -// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>) -> (memref<2xf32>, i1, memref<5xf32>) -// CHECK-NEXT: %[[RETURN_TUPLE:.*]] = "test.make_tuple"(%[[CALLEE_RESULTS]]#0, %[[CALLEE_RESULTS]]#1, %[[CALLEE_RESULTS]]#2) -// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[RETURN_TUPLE]]) {index = 0 : i32} -// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[RETURN_TUPLE]]) {index = 1 : i32} -// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[RETURN_TUPLE]]) {index = 2 : i32} -// CHECK-NEXT: return %[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]] - -// ----- - -// Test case: Testing BufferizeFuncOpConverter and -// BufferizeReturnOpConverter to see if the return operation matches with the -// new function signature when there are tuple typed args and results. -// BufferizeTypeConverter is set to flatten tuple typed arguments. The tuple -// typed values should be decomposed and composed using get_tuple_element and -// make_tuple operations of test dialect. Tensor types are converted to Memref. -// Memref typed function results remain as function results. 
- -// CHECK-LABEL: func @decompose_tuple_typed_function_args_and_results -func @decompose_tuple_typed_function_args_and_results(%arg0: tuple, %arg1: tensor<10xf32>, %arg2: tuple>) -> (tuple>, tensor<10xf32>, tuple){ - return %arg2, %arg1, %arg0 : tuple>, tensor<10xf32>, tuple -} -// CHECK-SAME: %[[ARG0:.*]]: i1, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: memref<10xf32>, %[[ARG3:.*]]: i1, %[[ARG4:.*]]: memref<5xf32> -// CHECK-SAME: (i1, memref<5xf32>, memref<10xf32>, i1, f32) -// CHECK-NEXT: %[[FIRST_TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]]) -// CHECK-NEXT: %[[SECOND_TUPLE:.*]] = "test.make_tuple"(%[[ARG3]], %[[ARG4]]) -// CHECK-NEXT: %[[SECOND_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 0 : i32} -// CHECK-NEXT: %[[SECOND_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 1 : i32} -// CHECK-NEXT: %[[FIRST_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 0 : i32} -// CHECK-NEXT: %[[FIRST_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 1 : i32} -// CHECK-NEXT: return %[[SECOND_TUPLE_FIRST_ELEM]], %[[SECOND_TUPLE_SECOND_ELEM]], %[[ARG2]], %[[FIRST_TUPLE_FIRST_ELEM]], %[[FIRST_TUPLE_SECOND_ELEM]] diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -3,6 +3,7 @@ TestAffineLoopParametricTiling.cpp TestExpandTanh.cpp TestCallGraph.cpp + TestDecomposeCallGraphTypes.cpp TestConstantFold.cpp TestConvVectorization.cpp TestConvertCallOp.cpp @@ -10,7 +11,6 @@ TestConvertGPUKernelToHsaco.cpp TestDominance.cpp TestDynamicPipeline.cpp - TestFinalizingBufferize.cpp TestLoopFusion.cpp TestGpuMemoryPromotion.cpp TestGpuParallelLoopMapping.cpp diff --git a/mlir/test/lib/Transforms/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Transforms/TestDecomposeCallGraphTypes.cpp new file mode 100644 --- /dev/null +++ 
b/mlir/test/lib/Transforms/TestDecomposeCallGraphTypes.cpp @@ -0,0 +1,97 @@ +//===- TestDecomposeCallGraphTypes.cpp - Test CG type decomposition -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestDialect.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.h" +#include "mlir/IR/Builders.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; + +namespace { +/// A pass for testing call graph type decomposition. +/// +/// This instantiates the patterns with a TypeConverter and ValueDecomposer +/// that splits tuple types into their respective element types. +/// For example, `tuple --> T1, T2, T3`. +struct TestDecomposeCallGraphTypes + : public PassWrapper> { + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() override { + auto module = getOperation(); + auto *context = &getContext(); + TypeConverter typeConverter; + ConversionTarget target(*context); + ValueDecomposer decomposer; + OwningRewritePatternList patterns; + + target.addLegalDialect(); + + target.addDynamicallyLegalOp([&](ReturnOp op) { + return typeConverter.isLegal(op.getOperandTypes()); + }); + target.addDynamicallyLegalOp( + [&](CallOp op) { return typeConverter.isLegal(op); }); + target.addDynamicallyLegalOp([&](FuncOp op) { + return typeConverter.isSignatureLegal(op.getType()); + }); + + typeConverter.addConversion([](Type type) { return type; }); + typeConverter.addConversion( + [](TupleType tupleType, SmallVectorImpl &types) { + tupleType.getFlattenedTypes(types); + return success(); + }); + + decomposer.addDecomposeValueConversion([](OpBuilder 
&builder, Location loc, + TupleType resultType, Value value, + SmallVectorImpl &values) { + for (unsigned i = 0, e = resultType.size(); i < e; ++i) { + Value res = builder.create( + loc, resultType.getType(i), value, builder.getI32IntegerAttr(i)); + values.push_back(res); + } + return success(); + }); + + typeConverter.addArgumentMaterialization( + [](OpBuilder &builder, TupleType resultType, ValueRange inputs, + Location loc) -> Optional { + if (inputs.size() == 1) + return llvm::None; + TypeRange TypeRange = inputs.getTypes(); + SmallVector types(TypeRange.begin(), TypeRange.end()); + TupleType tuple = TupleType::get(types, builder.getContext()); + Value value = builder.create(loc, tuple, inputs); + return value; + }); + + populateDecomposeCallGraphTypesPatterns(context, typeConverter, decomposer, + patterns); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // end anonymous namespace + +namespace mlir { +namespace test { +void registerTestDecomposeCallGraphTypes() { + PassRegistration pass( + "test-decompose-call-graph-types", + "Decomposes types at call graph boundaries."); +} +} // namespace test +} // namespace mlir diff --git a/mlir/test/lib/Transforms/TestFinalizingBufferize.cpp b/mlir/test/lib/Transforms/TestFinalizingBufferize.cpp deleted file mode 100644 --- a/mlir/test/lib/Transforms/TestFinalizingBufferize.cpp +++ /dev/null @@ -1,167 +0,0 @@ -//===- TestFinalizingBufferize.cpp - Finalizing bufferization ---*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements a pass that exercises the functionality of finalizing -// bufferizations. 
-// -//===----------------------------------------------------------------------===// - -#include "TestDialect.h" -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" -#include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/Function.h" -#include "mlir/IR/Operation.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Transforms/Bufferize.h" - -using namespace mlir; - -namespace { -/// This pass is a test for "finalizing" bufferize conversions. -/// -/// A "finalizing" bufferize conversion is one that performs a "full" conversion -/// and expects all tensors to be gone from the program. This in particular -/// involves rewriting funcs (including block arguments of the contained -/// region), calls, and returns. The unique property of finalizing bufferization -/// passes is that they cannot be done via a local transformation with suitable -/// materializations to ensure composability (as other bufferization passes do). -/// For example, if a call is rewritten, the callee needs to be rewritten -/// otherwise the IR will end up invalid. Thus, finalizing bufferization passes -/// require an atomic change to the entire program (e.g. the whole module). -/// -/// TODO: Split out BufferizeFinalizationPolicy from BufferizeTypeConverter. -struct TestFinalizingBufferizePass - : mlir::PassWrapper> { - - /// Converts tensor based test operations to buffer based ones using - /// bufferize. - class TensorBasedOpConverter - : public BufferizeOpConversionPattern { - public: - using BufferizeOpConversionPattern< - test::TensorBasedOp>::BufferizeOpConversionPattern; - - LogicalResult - matchAndRewrite(test::TensorBasedOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - mlir::test::TensorBasedOpAdaptor adaptor( - operands, op.getOperation()->getAttrDictionary()); - - // The input needs to be turned into a buffer first. Until then, bail out. 
- if (!adaptor.input().getType().isa()) - return failure(); - - Location loc = op.getLoc(); - - // Update the result type to a memref type. - auto type = op.getResult().getType().cast(); - if (!type.hasStaticShape()) - return rewriter.notifyMatchFailure( - op, "dynamic shapes not currently supported"); - auto memrefType = MemRefType::get(type.getShape(), type.getElementType()); - Value newOutputBuffer = rewriter.create(loc, memrefType); - - // Generate a new test operation that works on buffers. - rewriter.create(loc, - /*input=*/adaptor.input(), - /*output=*/newOutputBuffer); - - // Replace the results of the old op with the new output buffers. - rewriter.replaceOp(op, newOutputBuffer); - return success(); - } - }; - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - void runOnOperation() override { - MLIRContext &context = this->getContext(); - ConversionTarget target(context); - BufferizeTypeConverter converter; - - // Mark all Standard operations legal. - target.addLegalDialect(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - - // Mark all Test operations illegal as long as they work on tensors. - auto isLegalOperation = [&](Operation *op) { - return converter.isLegal(op); - }; - target.addDynamicallyLegalDialect(isLegalOperation); - - // Mark Standard Return operations illegal as long as one operand is tensor. - target.addDynamicallyLegalOp([&](mlir::ReturnOp returnOp) { - return converter.isLegal(returnOp.getOperandTypes()); - }); - - // Mark Standard Call Operation illegal as long as it operates on tensor. - target.addDynamicallyLegalOp( - [&](mlir::CallOp callOp) { return converter.isLegal(callOp); }); - - // Mark the function whose arguments are in tensor-type illegal. 
- target.addDynamicallyLegalOp([&](FuncOp funcOp) { - return converter.isSignatureLegal(funcOp.getType()) && - converter.isLegal(&funcOp.getBody()); - }); - - converter.addDecomposeTypeConversion( - [](TupleType tupleType, SmallVectorImpl &types) { - tupleType.getFlattenedTypes(types); - return success(); - }); - - converter.addArgumentMaterialization( - [](OpBuilder &builder, TupleType resultType, ValueRange inputs, - Location loc) -> Optional { - if (inputs.size() == 1) - return llvm::None; - TypeRange TypeRange = inputs.getTypes(); - SmallVector types(TypeRange.begin(), TypeRange.end()); - TupleType tuple = TupleType::get(types, builder.getContext()); - mlir::Value value = - builder.create(loc, tuple, inputs); - return value; - }); - - converter.addDecomposeValueConversion([](OpBuilder &builder, Location loc, - TupleType resultType, Value value, - SmallVectorImpl &values) { - for (unsigned i = 0, e = resultType.size(); i < e; ++i) { - Value res = builder.create( - loc, resultType.getType(i), value, builder.getI32IntegerAttr(i)); - values.push_back(res); - } - return success(); - }); - - OwningRewritePatternList patterns; - populateWithBufferizeOpConversionPatterns( - &context, converter, patterns); - patterns.insert(&context, converter); - - if (failed(applyFullConversion(this->getOperation(), target, - std::move(patterns)))) - this->signalPassFailure(); - }; -}; -} // end anonymous namespace - -namespace mlir { -namespace test { -void registerTestFinalizingBufferizePass() { - PassRegistration( - "test-finalizing-bufferize", "Tests finalizing bufferize conversions"); -} -} // namespace test -} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -63,11 +63,11 @@ void registerTestConvVectorization(); void registerTestConvertGPUKernelToCubinPass(); void registerTestConvertGPUKernelToHsacoPass(); +void registerTestDecomposeCallGraphTypes(); 
void registerTestDialect(DialectRegistry &); void registerTestDominancePass(); void registerTestDynamicPipelinePass(); void registerTestExpandTanhPass(); -void registerTestFinalizingBufferizePass(); void registerTestGpuParallelLoopMappingPass(); void registerTestInterfaces(); void registerTestLinalgCodegenStrategy(); @@ -128,10 +128,10 @@ test::registerTestConvertGPUKernelToHsacoPass(); #endif test::registerTestConvVectorization(); + test::registerTestDecomposeCallGraphTypes(); test::registerTestDominancePass(); test::registerTestDynamicPipelinePass(); test::registerTestExpandTanhPass(); - test::registerTestFinalizingBufferizePass(); test::registerTestGpuParallelLoopMappingPass(); test::registerTestInterfaces(); test::registerTestLinalgCodegenStrategy();