diff --git a/flang/lib/Optimizer/Dialect/FIRDialect.cpp b/flang/lib/Optimizer/Dialect/FIRDialect.cpp --- a/flang/lib/Optimizer/Dialect/FIRDialect.cpp +++ b/flang/lib/Optimizer/Dialect/FIRDialect.cpp @@ -46,13 +46,6 @@ for (const auto &it : llvm::enumerate(returnOp.getOperands())) valuesToRepl[it.index()].replaceAllUsesWith(it.value()); } - - mlir::Operation *materializeCallConversion(mlir::OpBuilder &builder, - mlir::Value input, - mlir::Type resultType, - mlir::Location loc) const final { - return builder.create<fir::ConvertOp>(loc, resultType, input); - } }; } // namespace diff --git a/mlir/docs/Tutorials/Toy/Ch-4.md b/mlir/docs/Tutorials/Toy/Ch-4.md --- a/mlir/docs/Tutorials/Toy/Ch-4.md +++ b/mlir/docs/Tutorials/Toy/Ch-4.md @@ -77,7 +77,7 @@ return true; } - /// This hook cheks if the given 'src' region can be inlined into the 'dest' + /// This hook checks if the given 'src' region can be inlined into the 'dest' /// region. The regions here are the bodies of the callable functions. For /// Toy, any function can be inlined, so we simply return true. bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, @@ -85,6 +85,15 @@ return true; } + /// This hook checks if 'sourceType' and 'targetType' are convertible or if + /// an argument or result type mismatch prevents inlining. For Toy, all types + /// are convertible, so we simply return true. + bool isTypeConvertible(Operation *call, Operation *callable, Type sourceType, + Type targetType, DictionaryAttr argOrResAttrs, + bool isResult) const final { + return true; + } + /// This hook is called when a terminator operation has been inlined. The only /// terminator that we have in the Toy dialect is the return /// operation(toy.return). 
We handle the return by replacing the values @@ -275,22 +284,33 @@ ``` -With a proper cast operation, we can now override the necessary hook on the +With a proper cast operation, we can now override the necessary hooks on the ToyInlinerInterface to insert it for us when necessary: ```c++ struct ToyInlinerInterface : public DialectInlinerInterface { ... - /// Attempts to materialize a conversion for a type mismatch between a call - /// from this dialect, and a callable region. This method should generate an - /// operation that takes 'input' as the only operand, and produces a single - /// result of 'resultType'. If a conversion can not be generated, nullptr - /// should be returned. - Operation *materializeCallConversion(OpBuilder &builder, Value input, - Type resultType, - Location conversionLoc) const final { - return builder.create<CastOp>(conversionLoc, resultType, input); + /// Handle possible type mismatches between the arguments of call and + /// callable. This handler introduces a cast operation if the argument type + /// differs from the target type. + Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable, + Value argument, Type targetType, + DictionaryAttr argumentAttrs) const final { + if (argument.getType() == targetType) + return argument; + return builder.create<CastOp>(call->getLoc(), targetType, argument); + } + + /// Handle possible type mismatches between the results of callable and call. + /// This handler introduces a cast operation if the result type differs from + /// the target type. 
+ Value handleResult(OpBuilder &builder, Operation *call, Operation *callable, + Value result, Type targetType, + DictionaryAttr resultAttrs) const final { + if (result.getType() == targetType) + return result; + return builder.create<CastOp>(call->getLoc(), targetType, result); } }; ``` diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp @@ -53,6 +53,13 @@ return true; } + /// All argument and result types within toy are convertible. + bool isTypeConvertible(Operation *call, Operation *callable, Type sourceType, + Type targetType, DictionaryAttr argOrResAttrs, + bool isResult) const final { + return true; + } + //===--------------------------------------------------------------------===// // Transformation Hooks //===--------------------------------------------------------------------===// @@ -70,15 +77,26 @@ valuesToRepl[it.index()].replaceAllUsesWith(it.value()); } - /// Attempts to materialize a conversion for a type mismatch between a call - /// from this dialect, and a callable region. This method should generate an - /// operation that takes 'input' as the only operand, and produces a single - /// result of 'resultType'. If a conversion can not be generated, nullptr - /// should be returned. - Operation *materializeCallConversion(OpBuilder &builder, Value input, - Type resultType, - Location conversionLoc) const final { - return builder.create<CastOp>(conversionLoc, resultType, input); + /// Handle possible type mismatches between the arguments of call and + /// callable. This handler introduces a cast operation if the argument type + /// differs from the target type. 
+ Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable, + Value argument, Type targetType, + DictionaryAttr argumentAttrs) const final { + if (argument.getType() == targetType) + return argument; + return builder.create<CastOp>(call->getLoc(), targetType, argument); + } + + /// Handle possible type mismatches between the results of callable and call. + /// This handler introduces a cast operation if the result type differs from + /// the target type. + Value handleResult(OpBuilder &builder, Operation *call, Operation *callable, + Value result, Type targetType, + DictionaryAttr resultAttrs) const final { + if (result.getType() == targetType) + return result; + return builder.create<CastOp>(call->getLoc(), targetType, result); } }; diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp @@ -53,6 +53,13 @@ return true; } + /// All argument and result types within toy are convertible. + bool isTypeConvertible(Operation *call, Operation *callable, Type sourceType, + Type targetType, DictionaryAttr argOrResAttrs, + bool isResult) const final { + return true; + } + //===--------------------------------------------------------------------===// // Transformation Hooks //===--------------------------------------------------------------------===// @@ -70,15 +77,26 @@ valuesToRepl[it.index()].replaceAllUsesWith(it.value()); } - /// Attempts to materialize a conversion for a type mismatch between a call - /// from this dialect, and a callable region. This method should generate an - /// operation that takes 'input' as the only operand, and produces a single - /// result of 'resultType'. If a conversion can not be generated, nullptr - /// should be returned. 
- Operation *materializeCallConversion(OpBuilder &builder, Value input, - Type resultType, - Location conversionLoc) const final { - return builder.create<CastOp>(conversionLoc, resultType, input); + /// Handle possible type mismatches between the arguments of call and + /// callable. This handler introduces a cast operation if the argument type + /// differs from the target type. + Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable, + Value argument, Type targetType, + DictionaryAttr argumentAttrs) const final { + if (argument.getType() == targetType) + return argument; + return builder.create<CastOp>(call->getLoc(), targetType, argument); + } + + /// Handle possible type mismatches between the results of callable and call. + /// This handler introduces a cast operation if the result type differs from + /// the target type. + Value handleResult(OpBuilder &builder, Operation *call, Operation *callable, + Value result, Type targetType, + DictionaryAttr resultAttrs) const final { + if (result.getType() == targetType) + return result; + return builder.create<CastOp>(call->getLoc(), targetType, result); } }; diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp @@ -53,6 +53,13 @@ return true; } + /// All argument and result types within toy are convertible. + bool isTypeConvertible(Operation *call, Operation *callable, Type sourceType, + Type targetType, DictionaryAttr argOrResAttrs, + bool isResult) const final { + return true; + } + //===--------------------------------------------------------------------===// // Transformation Hooks //===--------------------------------------------------------------------===// @@ -70,15 +77,26 @@ valuesToRepl[it.index()].replaceAllUsesWith(it.value()); } - /// Attempts to materialize a conversion for a type mismatch between a call - /// from this dialect, and a callable region. 
This method should generate an - /// operation that takes 'input' as the only operand, and produces a single - /// result of 'resultType'. If a conversion can not be generated, nullptr - /// should be returned. - Operation *materializeCallConversion(OpBuilder &builder, Value input, - Type resultType, - Location conversionLoc) const final { - return builder.create<CastOp>(conversionLoc, resultType, input); + /// Handle possible type mismatches between the arguments of call and + /// callable. This handler introduces a cast operation if the argument type + /// differs from the target type. + Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable, + Value argument, Type targetType, + DictionaryAttr argumentAttrs) const final { + if (argument.getType() == targetType) + return argument; + return builder.create<CastOp>(call->getLoc(), targetType, argument); + } + + /// Handle possible type mismatches between the results of callable and call. + /// This handler introduces a cast operation if the result type differs from + /// the target type. + Value handleResult(OpBuilder &builder, Operation *call, Operation *callable, + Value result, Type targetType, + DictionaryAttr resultAttrs) const final { + if (result.getType() == targetType) + return result; + return builder.create<CastOp>(call->getLoc(), targetType, result); } }; diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp @@ -54,6 +54,13 @@ return true; } + /// All argument and result types within toy are convertible. 
+ bool isTypeConvertible(Operation *call, Operation *callable, Type sourceType, + Type targetType, DictionaryAttr argOrResAttrs, + bool isResult) const final { + return true; + } + //===--------------------------------------------------------------------===// // Transformation Hooks //===--------------------------------------------------------------------===// @@ -71,15 +78,26 @@ valuesToRepl[it.index()].replaceAllUsesWith(it.value()); } - /// Attempts to materialize a conversion for a type mismatch between a call - /// from this dialect, and a callable region. This method should generate an - /// operation that takes 'input' as the only operand, and produces a single - /// result of 'resultType'. If a conversion can not be generated, nullptr - /// should be returned. - Operation *materializeCallConversion(OpBuilder &builder, Value input, - Type resultType, - Location conversionLoc) const final { - return builder.create<CastOp>(conversionLoc, resultType, input); + /// Handle possible type mismatches between the arguments of call and + /// callable. This handler introduces a cast operation if the argument type + /// differs from the target type. + Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable, + Value argument, Type targetType, + DictionaryAttr argumentAttrs) const final { + if (argument.getType() == targetType) + return argument; + return builder.create<CastOp>(call->getLoc(), targetType, argument); + } + + /// Handle possible type mismatches between the results of callable and call. + /// This handler introduces a cast operation if the result type differs from + /// the target type. 
+ Value handleResult(OpBuilder &builder, Operation *call, Operation *callable, + Value result, Type targetType, + DictionaryAttr resultAttrs) const final { + if (result.getType() == targetType) + return result; + return builder.create<CastOp>(call->getLoc(), targetType, result); } }; diff --git a/mlir/include/mlir/Transforms/InliningUtils.h b/mlir/include/mlir/Transforms/InliningUtils.h --- a/mlir/include/mlir/Transforms/InliningUtils.h +++ b/mlir/include/mlir/Transforms/InliningUtils.h @@ -85,6 +85,22 @@ return false; } + /// Returns true if `sourceType`, either a call argument type or a callee + /// result type, can be converted to `targetType`, or false otherwise. This + /// hook is only called for arguments and results that require type conversion + /// for inlining. It executes right before the `isLegalToInline` method, which + /// checks if a callable can be inlined in-place of a call. The `callable` is + /// registered to the current dialect and implements `CallableOpInterface`, + /// while the `call` implements `CallOpInterface`. The `argOrResAttrs` + /// dictionary contains the argument or result attributes and is guaranteed to + /// be non-null even if empty; `isResult` is true for a callee result. + virtual bool isTypeConvertible(Operation *call, Operation *callable, + Type sourceType, Type targetType, + DictionaryAttr argOrResAttrs, + bool isResult) const { + return false; + } + /// This hook is invoked on an operation that contains regions. It should /// return true if the analyzer should recurse within the regions of this /// operation when computing legality and cost, false otherwise. The default @@ -121,37 +137,15 @@ "must implement handleTerminator in the case of one inlined block"); } - /// Attempt to materialize a conversion for a type mismatch between a call - /// from this dialect, and a callable region. This method should generate an - /// operation that takes 'input' as the only operand, and produces a single - /// result of 'resultType'.
If a conversion can not be generated, nullptr - /// should be returned. For example, this hook may be invoked in the following - /// scenarios: - /// func @foo(i32) -> i32 { ... } - /// - /// // Mismatched input operand - /// ... = foo.call @foo(%input : i16) -> i32 - /// - /// // Mismatched result type. - /// ... = foo.call @foo(%input : i32) -> i16 - /// - /// NOTE: This hook may be invoked before the 'isLegal' checks above. - virtual Operation *materializeCallConversion(OpBuilder &builder, Value input, - Type resultType, - Location conversionLoc) const { - return nullptr; - } - /// Hook to transform the call arguments before using them to replace the /// callee arguments. It returns the transformation result or `argument` /// itself if the hook did not change anything. The type of the returned value /// has to match `targetType`, and the `argumentAttrs` dictionary is non-null - /// even if no attribute is present. The hook is called after converting the - /// callsite argument types using the materializeCallConversion callback, and - /// right before inlining the callee region. Any operations created using the - /// provided `builder` are inserted right before the inlined callee region. - /// Example use cases are the insertion of copies for by value arguments, or - /// integer conversions that require signedness information. + /// even if no attributes are present. The hook is called right before + /// inlining the callee region. Any operations created using the provided + /// `builder` are inserted right before the inlined callee region. Example use + /// cases are the insertion of copies for by value arguments, or integer + /// conversions that require signedness information. virtual Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable, Value argument, Type targetType, @@ -163,12 +157,11 @@ /// results. It returns the transformation result or the `result` itself if /// the hook did not change anything. 
The type of the returned values has to /// match `targetType`, and the `resultAttrs` dictionary is non-null even if - /// no attribute is present. The hook is called right before handling - /// terminators, and obtains the callee result before converting its type - /// using the `materializeCallConversion` callback. Any operations created - /// using the provided `builder` are inserted right after the inlined callee - /// region. Example use cases are the insertion of copies for by value results - /// or integer conversions that require signedness information. + /// no attributes are present. The hook is called right before handling + /// terminators. Any operations created using the provided `builder` are + /// inserted right after the inlined callee region. Example use cases are the + /// insertion of copies for by value results or integer conversions that + /// require signedness information. /// NOTE: This hook is invoked after inlining the `callable` region. virtual Value handleResult(OpBuilder &builder, Operation *call, Operation *callable, Value result, Type targetType, @@ -209,6 +202,10 @@ IRMapping &valueMapping) const; virtual bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, IRMapping &valueMapping) const; + virtual bool isTypeConvertible(Operation *call, Operation *callable, + Type sourceType, Type targetType, + DictionaryAttr argOrResAttrs, + bool isResult) const; virtual bool shouldAnalyzeRecursively(Operation *op) const; //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp --- a/mlir/lib/Transforms/Utils/InliningUtils.cpp +++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp @@ -81,6 +81,16 @@ return false; } +bool InlinerInterface::isTypeConvertible(Operation *call, Operation *callable, + Type sourceType, Type targetType, + DictionaryAttr argOrResAttrs, + bool isResult) const { + if (auto *handler = 
getInterfaceFor(callable)) + return handler->isTypeConvertible(call, callable, sourceType, targetType, + argOrResAttrs, isResult); + return false; +} + bool InlinerInterface::shouldAnalyzeRecursively(Operation *op) const { auto *handler = getInterfaceFor(op); return handler ? handler->shouldAnalyzeRecursively(op) : true; @@ -161,23 +171,43 @@ // Inline Methods //===----------------------------------------------------------------------===// +/// Converts an array attribute to an array of non-null attribute dictionaries. +/// Returns an array of empty dictionaries if the array attribute is null. +static SmallVector<DictionaryAttr> +getAttributesImpl(CallableOpInterface callable, size_t arraySize, + ArrayAttr arrayAttr) { + if (!arrayAttr) + return SmallVector<DictionaryAttr>( + arraySize, DictionaryAttr::get(callable.getContext(), {})); + assert(arraySize == arrayAttr.size() && "attribute array size mismatch"); + return llvm::to_vector(arrayAttr.getAsRange<DictionaryAttr>()); +} + +/// Returns the array of argument attribute dictionaries. +static SmallVector<DictionaryAttr> +getArgumentAttributes(CallableOpInterface callable) { + return getAttributesImpl(callable, + callable.getCallableRegion()->getNumArguments(), + callable.getCallableArgAttrs()); +} + +/// Returns the array of result attribute dictionaries. +static SmallVector<DictionaryAttr> +getResultAttributes(CallableOpInterface callable) { + return getAttributesImpl(callable, callable.getCallableResults().size(), + callable.getCallableResAttrs()); +} + static void handleArgumentImpl(InlinerInterface &interface, OpBuilder &builder, CallOpInterface call, CallableOpInterface callable, IRMapping &mapper) { - // Unpack the argument attributes if there are any.
- SmallVector<DictionaryAttr> argAttrs( - callable.getCallableRegion()->getNumArguments(), - builder.getDictionaryAttr({})); - if (ArrayAttr arrayAttr = callable.getCallableArgAttrs()) { - assert(arrayAttr.size() == argAttrs.size()); - for (auto [idx, attr] : llvm::enumerate(arrayAttr)) - argAttrs[idx] = cast<DictionaryAttr>(attr); - } + // Unpack the argument attributes. + SmallVector<DictionaryAttr> argAttrs = getArgumentAttributes(callable); // Run the argument attribute handler for the given argument and attribute. - for (auto [blockArg, argAttr] : - llvm::zip(callable.getCallableRegion()->getArguments(), argAttrs)) { + for (auto [blockArg, argAttr] : llvm::zip_equal( + callable.getCallableRegion()->getArguments(), argAttrs)) { Value newArgument = interface.handleArgument(builder, call, callable, mapper.lookup(blockArg), blockArg.getType(), argAttr); @@ -192,31 +222,21 @@ static void handleResultImpl(InlinerInterface &interface, OpBuilder &builder, CallOpInterface call, CallableOpInterface callable, ValueRange results) { - // Unpack the result attributes if there are any. - SmallVector<DictionaryAttr> resAttrs(results.size(), - builder.getDictionaryAttr({})); - if (ArrayAttr arrayAttr = callable.getCallableResAttrs()) { - assert(arrayAttr.size() == resAttrs.size()); - for (auto [idx, attr] : llvm::enumerate(arrayAttr)) - resAttrs[idx] = cast<DictionaryAttr>(attr); - } + // Unpack the result attributes. + SmallVector<DictionaryAttr> resAttrs = getResultAttributes(callable); // Run the result attribute handler for the given result and attribute. SmallVector<DictionaryAttr> resultAttributes; - for (auto [result, resAttr] : llvm::zip(results, resAttrs)) { + for (auto [result, resAttr, callResult] : + llvm::zip_equal(results, resAttrs, call->getResults())) { // Store the original result users before running the handler. 
DenseSet<Operation *> resultUsers; for (Operation *user : result.getUsers()) resultUsers.insert(user); - // TODO: Use the type of the call result to replace once the hook can be - // used for type conversions. At the moment, all type conversions have to be - // done using materializeCallConversion. - Type targetType = result.getType(); - Value newResult = interface.handleResult(builder, call, callable, result, - targetType, resAttr); - assert(newResult.getType() == targetType && + callResult.getType(), resAttr); + assert(newResult.getType() == callResult.getType() && "expected the handled result type to match the target type"); // Replace the result uses except for the ones introduce by the handler. @@ -407,28 +427,6 @@ shouldCloneInlinedRegion); } -/// Utility function used to generate a cast operation from the given interface, -/// or return nullptr if a cast could not be generated. -static Value materializeConversion(const DialectInlinerInterface *interface, - SmallVectorImpl<Operation *> &castOps, - OpBuilder &castBuilder, Value arg, Type type, - Location conversionLoc) { - if (!interface) - return nullptr; - - // Check to see if the interface for the call can materialize a conversion. - Operation *castOp = interface->materializeCallConversion(castBuilder, arg, - type, conversionLoc); - if (!castOp) - return nullptr; - castOps.push_back(castOp); - - // Ensure that the generated cast is correct. - assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg && - castOp->getNumResults() == 1 && *castOp->result_type_begin() == type); - return castOp->getResult(0); -} - /// This function inlines a given region, 'src', of a callable operation, /// 'callable', into the location defined by the given call operation. This /// function returns failure if inlining is not possible, success otherwise. 
On @@ -453,69 +451,41 @@ callResults.size() != callableResultTypes.size()) return failure(); - // A set of cast operations generated to matchup the signature of the region - // with the signature of the call. - SmallVector<Operation *, 4> castOps; - castOps.reserve(callOperands.size() + callResults.size()); - - // Functor used to cleanup generated state on failure. - auto cleanupState = [&] { - for (auto *op : castOps) { - op->getResult(0).replaceAllUsesWith(op->getOperand(0)); - op->erase(); - } - return failure(); - }; - - // Builder used for any conversion operations that need to be materialized. - OpBuilder castBuilder(call); - Location castLoc = call.getLoc(); - const auto *callInterface = interface.getInterfaceFor(call->getDialect()); - - // Map the provided call operands to the arguments of the region. - IRMapping mapper; - for (unsigned i = 0, e = callOperands.size(); i != e; ++i) { - BlockArgument regionArg = entryBlock->getArgument(i); - Value operand = callOperands[i]; - - // If the call operand doesn't match the expected region argument, try to - // generate a cast. - Type regionArgType = regionArg.getType(); - if (operand.getType() != regionArgType) { - if (!(operand = materializeConversion(callInterface, castOps, castBuilder, - operand, regionArgType, castLoc))) - return cleanupState(); - } - mapper.map(regionArg, operand); + // Check that argument types are convertible. + SmallVector<DictionaryAttr> argAttrs = getArgumentAttributes(callable); + for (auto [argumentType, targetType, argumentAttrs] : + llvm::zip_equal(call.getArgOperands().getTypes(), + entryBlock->getArgumentTypes(), argAttrs)) { + if (argumentType == targetType) + continue; + if (!interface.isTypeConvertible(call, callable, argumentType, targetType, + argumentAttrs, /*isResult=*/false)) + return failure(); } - // Ensure that the resultant values of the call match the callable.
- castBuilder.setInsertionPointAfter(call); - for (unsigned i = 0, e = callResults.size(); i != e; ++i) { - Value callResult = callResults[i]; - if (callResult.getType() == callableResultTypes[i]) + // Check that result types are convertible. + SmallVector<DictionaryAttr> resAttrs = getResultAttributes(callable); + for (auto [resultType, targetType, resultAttrs] : + llvm::zip_equal(callableResultTypes, call->getResultTypes(), resAttrs)) { + if (resultType == targetType) continue; - - // Generate a conversion that will produce the original type, so that the IR - // is still valid after the original call gets replaced. - Value castResult = - materializeConversion(callInterface, castOps, castBuilder, callResult, - callResult.getType(), castLoc); - if (!castResult) - return cleanupState(); - callResult.replaceAllUsesWith(castResult); - castResult.getDefiningOp()->replaceUsesOfWith(castResult, callResult); + if (!interface.isTypeConvertible(call, callable, resultType, targetType, + resultAttrs, /*isResult=*/true)) + return failure(); } // Check that it is legal to inline the callable into the call. if (!interface.isLegalToInline(call, callable, shouldCloneInlinedRegion)) - return cleanupState(); + return failure(); + + IRMapping mapper; + for (auto [blockArg, operand] : + llvm::zip_equal(entryBlock->getArguments(), callOperands)) + mapper.map(blockArg, operand); // Attempt to inline the call.
- if (failed(inlineRegionImpl(interface, src, call->getBlock(), - ++call->getIterator(), mapper, callResults, - callableResultTypes, call.getLoc(), - shouldCloneInlinedRegion, call))) - return cleanupState(); - return success(); + return inlineRegionImpl(interface, src, call->getBlock(), + ++call->getIterator(), mapper, callResults, + callableResultTypes, call.getLoc(), + shouldCloneInlinedRegion, call); } diff --git a/mlir/test/Transforms/inlining.mlir b/mlir/test/Transforms/inlining.mlir --- a/mlir/test/Transforms/inlining.mlir +++ b/mlir/test/Transforms/inlining.mlir @@ -110,15 +110,15 @@ } // Check that we can convert types for inputs and results as necessary. -func.func @convert_callee_fn(%arg : i32) -> i32 { - return %arg : i32 +test.conversion_func_op @convert_callee_fn(%arg : i32) -> i32 { + "test.return"(%arg) : (i32) -> () } -func.func @convert_callee_fn_multi_arg(%a : i32, %b : i32) -> () { - return +test.conversion_func_op @convert_callee_fn_multi_arg(%a : i32, %b : i32) -> () { + "test.return"() : () -> () } -func.func @convert_callee_fn_multi_res() -> (i32, i32) { +test.conversion_func_op @convert_callee_fn_multi_res() -> (i32, i32) { %res = arith.constant 0 : i32 - return %res, %res : i32, i32 + "test.return"(%res, %res) : (i32, i32) -> () } // CHECK-LABEL: func @inline_convert_call @@ -133,24 +133,24 @@ return %res : i16 } -func.func @convert_callee_fn_multiblock() -> i32 { - cf.br ^bb0 +test.conversion_func_op @convert_callee_fn_multiblock(%arg0 : i32) -> i32 { + "test.br"()[^bb0] : () -> () ^bb0: - %0 = arith.constant 0 : i32 - return %0 : i32 + "test.return"(%arg0) : (i32) -> () } // CHECK-LABEL: func @inline_convert_result_multiblock func.func @inline_convert_result_multiblock() -> i16 { -// CHECK: cf.br ^bb1 {inlined_conversion} -// CHECK: ^bb1: -// CHECK: %[[C:.+]] = arith.constant {inlined_conversion} 0 : i32 -// CHECK: cf.br ^bb2(%[[C]] : i32) -// CHECK: ^bb2(%[[BBARG:.+]]: i32): -// CHECK: %[[CAST_RESULT:.+]] = "test.cast"(%[[BBARG]]) : 
(i32) -> i16 -// CHECK: return %[[CAST_RESULT]] : i16 - - %res = "test.conversion_call_op"() { callee=@convert_callee_fn_multiblock } : () -> (i16) + // CHECK: %[[INPUT:.+]] = arith.constant + // CHECK: %[[CAST_INPUT:.+]] = "test.cast"(%[[INPUT]]) : (i16) -> i32 + // CHECK: "test.br"()[^bb1] {inlined_conversion} : () -> () + // CHECK: ^bb1: + // CHECK: "test.br"(%[[CAST_INPUT]])[^bb2] : (i32) -> () + // CHECK: ^bb2(%[[BBARG:.+]]: i32): + // CHECK: %[[CAST_RESULT:.+]] = "test.cast"(%[[BBARG]]) : (i32) -> i16 + // CHECK: return %[[CAST_RESULT]] : i16 + %test_input = arith.constant 0 : i16 + %res = "test.conversion_call_op"(%test_input) { callee=@convert_callee_fn_multiblock } : (i16) -> (i16) return %res : i16 } @@ -233,9 +233,6 @@ %1 = arith.subi %arg0, %arg1 : i16 "test.return"(%0, %1) : (i16, i16) -> () } -test.conversion_func_op @handle_attr_callee_fn(%arg0 : i32 {"test.handle_argument"}) -> (i32 {"test.handle_result"}) { - "test.return"(%arg0) : (i32) -> () -} // CHECK-LABEL: func @inline_handle_attr_call // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] @@ -250,16 +247,3 @@ %res0, %res1 = "test.conversion_call_op"(%arg0, %arg1) { callee=@handle_attr_callee_fn_multi_arg } : (i16, i16) -> (i16, i16) return %res0, %res1 : i16, i16 } - -// CHECK-LABEL: func @inline_convert_and_handle_attr_call -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -func.func @inline_convert_and_handle_attr_call(%arg0 : i16) -> (i16) { - - // CHECK: %[[CAST_INPUT:.*]] = "test.cast"(%[[ARG0]]) : (i16) -> i32 - // CHECK: %[[CHANGE_INPUT:.*]] = "test.type_changer"(%[[CAST_INPUT]]) : (i32) -> i32 - // CHECK: %[[CHANGE_RESULT:.*]] = "test.type_changer"(%[[CHANGE_INPUT]]) : (i32) -> i32 - // CHECK: %[[CAST_RESULT:.*]] = "test.cast"(%[[CHANGE_RESULT]]) : (i32) -> i16 - // CHECK: return %[[CAST_RESULT]] - %res = "test.conversion_call_op"(%arg0) { callee=@handle_attr_callee_fn } : (i16) -> (i16) - return %res : i16 -} diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp 
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -305,6 +305,14 @@ // Don't allow inlining calls that are marked `noinline`. return !call->hasAttr("noinline"); } + bool isTypeConvertible(Operation *call, Operation *callable, Type sourceType, + Type targetType, DictionaryAttr argOrResAttrs, + bool isResult) const final { + return (sourceType.isSignlessInteger(16) || + sourceType.isSignlessInteger(32)) && + (targetType.isSignlessInteger(16) || + targetType.isSignlessInteger(32)); + } bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final { // Inlining into test dialect regions is legal. return true; @@ -325,6 +333,20 @@ /// Handle the given inlined terminator by replacing it with a new operation /// as necessary. + void handleTerminator(Operation *op, Block *newDest) const final { + // Only handle "test.return" here. + auto returnOp = dyn_cast<TestReturnOp>(op); + if (!returnOp) + return; + + // Replace the return with a branch to the dest. + OpBuilder builder(op); + builder.create<TestBranchOp>(op->getLoc(), op->getOperands(), newDest); + op->erase(); + } + + /// Handle the given inlined terminator by replacing the values to replace by + /// the terminator operands. void handleTerminator(Operation *op, ArrayRef<Value> valuesToRepl) const final { // Only handle "test.return" here. @@ -338,39 +360,32 @@ valuesToRepl[it.index()].replaceAllUsesWith(it.value()); } - /// Attempt to materialize a conversion for a type mismatch between a call - /// from this dialect, and a callable region. This method should generate an - /// operation that takes 'input' as the only operand, and produces a single - /// result of 'resultType'. If a conversion can not be generated, nullptr - /// should be returned. - Operation *materializeCallConversion(OpBuilder &builder, Value input, - Type resultType, - Location conversionLoc) const final { - // Only allow conversion for i16/i32 types. 
- if (!(resultType.isSignlessInteger(16) || - resultType.isSignlessInteger(32)) || - !(input.getType().isSignlessInteger(16) || - input.getType().isSignlessInteger(32))) - return nullptr; - return builder.create<TestCastOp>(conversionLoc, resultType, input); - } - Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable, Value argument, Type targetType, DictionaryAttr argumentAttrs) const final { - if (!argumentAttrs.contains("test.handle_argument")) + // Cast if an attribute is present. + if (argumentAttrs.contains("test.handle_argument")) + return builder.create<TestTypeChangerOp>(call->getLoc(), targetType, + argument); + + // Cast only if the types do not match. + if (argument.getType() == targetType) return argument; - return builder.create<TestTypeChangerOp>(call->getLoc(), targetType, - argument); + return builder.create<TestCastOp>(call->getLoc(), targetType, argument); } Value handleResult(OpBuilder &builder, Operation *call, Operation *callable, Value result, Type targetType, DictionaryAttr resultAttrs) const final { - if (!resultAttrs.contains("test.handle_result")) + // Cast if an attribute is present. + if (resultAttrs.contains("test.handle_result")) + return builder.create<TestTypeChangerOp>(call->getLoc(), targetType, + result); + + // Cast only if the types do not match. + if (result.getType() == targetType) return result; - return builder.create<TestTypeChangerOp>(call->getLoc(), targetType, - result); + return builder.create<TestCastOp>(call->getLoc(), targetType, result); } void processInlinedCallBlocks(