diff --git a/flang/lib/Optimizer/Dialect/FIRDialect.cpp b/flang/lib/Optimizer/Dialect/FIRDialect.cpp --- a/flang/lib/Optimizer/Dialect/FIRDialect.cpp +++ b/flang/lib/Optimizer/Dialect/FIRDialect.cpp @@ -35,6 +35,14 @@ return fir::canLegallyInline(op, reg, wouldBeCloned, map); } + /// This hook checks if `sourceType` is convertible to `targetType`. + bool isTypeConvertible(mlir::Operation *call, mlir::Operation *callable, + mlir::Type sourceType, mlir::Type targetType, + mlir::DictionaryAttr argOrResAttrs, + bool isResult) const final { + return true; + } + /// This hook is called when a terminator operation has been inlined. /// We handle the return (a Fortran FUNCTION) by replacing the values /// previously returned by the call operation with the operands of the @@ -47,11 +55,24 @@ valuesToRepl[it.index()].replaceAllUsesWith(it.value()); } - mlir::Operation *materializeCallConversion(mlir::OpBuilder &builder, - mlir::Value input, - mlir::Type resultType, - mlir::Location loc) const final { - return builder.create(loc, resultType, input); + /// This hook converts the `argument` to `targetType` if needed. + mlir::Value handleArgument(mlir::OpBuilder &builder, mlir::Operation *call, + mlir::Operation *callable, mlir::Value argument, + mlir::Type targetType, + mlir::DictionaryAttr argumentAttrs) const final { + // The convert operation folds if the argument type matches the target type. + return builder.createOrFold(call->getLoc(), targetType, + argument); + } + + /// This hook converts the `result` to `targetType` if needed. + mlir::Value handleResult(mlir::OpBuilder &builder, mlir::Operation *call, + mlir::Operation *callable, mlir::Value result, + mlir::Type targetType, + mlir::DictionaryAttr resultAttrs) const final { + // The convert operation folds if the result type matches the target type. + return builder.createOrFold(call->getLoc(), targetType, + result); } }; } // namespace diff --git a/mlir/docs/Tutorials/Toy/Ch-4.md b/mlir/docs/Tutorials/Toy/Ch-4.md --- a/mlir/docs/Tutorials/Toy/Ch-4.md +++ b/mlir/docs/Tutorials/Toy/Ch-4.md @@ -77,7 +77,7 @@ return true; } - /// This hook cheks if the given 'src' region can be inlined into the 'dest' + /// This hook checks if the given 'src' region can be inlined into the 'dest' /// region. The regions here are the bodies of the callable functions. For /// Toy, any function can be inlined, so we simply return true. bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, @@ -85,6 +85,15 @@ return true; } + /// This hook checks if 'sourceType' and 'targetType' are convertible or if + /// an argument or result type mismatch prevents inlining. For Toy, all types + /// are convertible, so we simply return true. + bool isTypeConvertible(Operation *call, Operation *callable, Type sourceType, + Type targetType, DictionaryAttr argOrResAttrs, + bool isResult) const final { + return true; + } + /// This hook is called when a terminator operation has been inlined. The only /// terminator that we have in the Toy dialect is the return /// operation(toy.return). We handle the return by replacing the values @@ -263,22 +272,33 @@ ``` -With a proper cast operation, we can now override the necessary hook on the +With a proper cast operation, we can now override the necessary hooks on the ToyInlinerInterface to insert it for us when necessary: ```c++ struct ToyInlinerInterface : public DialectInlinerInterface { ... - /// Attempts to materialize a conversion for a type mismatch between a call - /// from this dialect, and a callable region. This method should generate an - /// operation that takes 'input' as the only operand, and produces a single - /// result of 'resultType'. If a conversion can not be generated, nullptr - /// should be returned. - Operation *materializeCallConversion(OpBuilder &builder, Value input, - Type resultType, - Location conversionLoc) const final { - return builder.create(conversionLoc, resultType, input); + /// Handle possible type mismatches between the arguments of call and + /// callable. This handler introduces a cast operation if the argument type + /// differs from the target type. + Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable, + Value argument, Type targetType, + DictionaryAttr argumentAttrs) const final { + if (argument.getType() == targetType) + return argument; + return builder.create(call->getLoc(), targetType, argument); + } + + /// Handle possible type mismatches between the results of callable and call. + /// This handler introduces a cast operation if the result type differs from + /// the target type. + Value handleResult(OpBuilder &builder, Operation *call, Operation *callable, + Value result, Type targetType, + DictionaryAttr resultAttrs) const final { + if (result.getType() == targetType) + return result; + return builder.create(call->getLoc(), targetType, result); } }; ``` diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp @@ -53,6 +53,13 @@ return true; } + /// All argument and result types within toy are convertible. + bool isTypeConvertible(Operation *call, Operation *callable, Type sourceType, + Type targetType, DictionaryAttr argOrResAttrs, + bool isResult) const final { + return true; + } + //===--------------------------------------------------------------------===// // Transformation Hooks //===--------------------------------------------------------------------===// @@ -70,15 +77,26 @@ valuesToRepl[it.index()].replaceAllUsesWith(it.value()); } - /// Attempts to materialize a conversion for a type mismatch between a call - /// from this dialect, and a callable region. This method should generate an - /// operation that takes 'input' as the only operand, and produces a single - /// result of 'resultType'. If a conversion can not be generated, nullptr - /// should be returned. - Operation *materializeCallConversion(OpBuilder &builder, Value input, - Type resultType, - Location conversionLoc) const final { - return builder.create(conversionLoc, resultType, input); + /// Handle possible type mismatches between the arguments of call and + /// callable. This handler introduces a cast operation if the argument type + /// differs from the target type. + Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable, + Value argument, Type targetType, + DictionaryAttr argumentAttrs) const final { + if (argument.getType() == targetType) + return argument; + return builder.create(call->getLoc(), targetType, argument); + } + + /// Handle possible type mismatches between the results of callable and call. + /// This handler introduces a cast operation if the result type differs from + /// the target type. + Value handleResult(OpBuilder &builder, Operation *call, Operation *callable, + Value result, Type targetType, + DictionaryAttr resultAttrs) const final { + if (result.getType() == targetType) + return result; + return builder.create(call->getLoc(), targetType, result); } }; diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp @@ -53,6 +53,13 @@ return true; } + /// All argument and result types within toy are convertible. + bool isTypeConvertible(Operation *call, Operation *callable, Type sourceType, + Type targetType, DictionaryAttr argOrResAttrs, + bool isResult) const final { + return true; + } + //===--------------------------------------------------------------------===// // Transformation Hooks //===--------------------------------------------------------------------===// @@ -70,15 +77,26 @@ valuesToRepl[it.index()].replaceAllUsesWith(it.value()); } - /// Attempts to materialize a conversion for a type mismatch between a call - /// from this dialect, and a callable region. This method should generate an - /// operation that takes 'input' as the only operand, and produces a single - /// result of 'resultType'. If a conversion can not be generated, nullptr - /// should be returned. - Operation *materializeCallConversion(OpBuilder &builder, Value input, - Type resultType, - Location conversionLoc) const final { - return builder.create(conversionLoc, resultType, input); + /// Handle possible type mismatches between the arguments of call and + /// callable. This handler introduces a cast operation if the argument type + /// differs from the target type. + Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable, + Value argument, Type targetType, + DictionaryAttr argumentAttrs) const final { + if (argument.getType() == targetType) + return argument; + return builder.create(call->getLoc(), targetType, argument); + } + + /// Handle possible type mismatches between the results of callable and call. + /// This handler introduces a cast operation if the result type differs from + /// the target type. + Value handleResult(OpBuilder &builder, Operation *call, Operation *callable, + Value result, Type targetType, + DictionaryAttr resultAttrs) const final { + if (result.getType() == targetType) + return result; + return builder.create(call->getLoc(), targetType, result); } }; diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp @@ -53,6 +53,13 @@ return true; } + /// All argument and result types within toy are convertible. + bool isTypeConvertible(Operation *call, Operation *callable, Type sourceType, + Type targetType, DictionaryAttr argOrResAttrs, + bool isResult) const final { + return true; + } + //===--------------------------------------------------------------------===// // Transformation Hooks //===--------------------------------------------------------------------===// @@ -70,15 +77,26 @@ valuesToRepl[it.index()].replaceAllUsesWith(it.value()); } - /// Attempts to materialize a conversion for a type mismatch between a call - /// from this dialect, and a callable region. This method should generate an - /// operation that takes 'input' as the only operand, and produces a single - /// result of 'resultType'. If a conversion can not be generated, nullptr - /// should be returned. - Operation *materializeCallConversion(OpBuilder &builder, Value input, - Type resultType, - Location conversionLoc) const final { - return builder.create(conversionLoc, resultType, input); + /// Handle possible type mismatches between the arguments of call and + /// callable. This handler introduces a cast operation if the argument type + /// differs from the target type. + Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable, + Value argument, Type targetType, + DictionaryAttr argumentAttrs) const final { + if (argument.getType() == targetType) + return argument; + return builder.create(call->getLoc(), targetType, argument); + } + + /// Handle possible type mismatches between the results of callable and call. + /// This handler introduces a cast operation if the result type differs from + /// the target type. + Value handleResult(OpBuilder &builder, Operation *call, Operation *callable, + Value result, Type targetType, + DictionaryAttr resultAttrs) const final { + if (result.getType() == targetType) + return result; + return builder.create(call->getLoc(), targetType, result); } }; diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp @@ -54,6 +54,13 @@ return true; } + /// All argument and result types within toy are convertible. + bool isTypeConvertible(Operation *call, Operation *callable, Type sourceType, + Type targetType, DictionaryAttr argOrResAttrs, + bool isResult) const final { + return true; + } + //===--------------------------------------------------------------------===// // Transformation Hooks //===--------------------------------------------------------------------===// @@ -71,15 +78,26 @@ valuesToRepl[it.index()].replaceAllUsesWith(it.value()); } - /// Attempts to materialize a conversion for a type mismatch between a call - /// from this dialect, and a callable region. This method should generate an - /// operation that takes 'input' as the only operand, and produces a single - /// result of 'resultType'. If a conversion can not be generated, nullptr - /// should be returned. - Operation *materializeCallConversion(OpBuilder &builder, Value input, - Type resultType, - Location conversionLoc) const final { - return builder.create(conversionLoc, resultType, input); + /// Handle possible type mismatches between the arguments of call and + /// callable. This handler introduces a cast operation if the argument type + /// differs from the target type. + Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable, + Value argument, Type targetType, + DictionaryAttr argumentAttrs) const final { + if (argument.getType() == targetType) + return argument; + return builder.create(call->getLoc(), targetType, argument); + } + + /// Handle possible type mismatches between the results of callable and call. + /// This handler introduces a cast operation if the result type differs from + /// the target type. + Value handleResult(OpBuilder &builder, Operation *call, Operation *callable, + Value result, Type targetType, + DictionaryAttr resultAttrs) const final { + if (result.getType() == targetType) + return result; + return builder.create(call->getLoc(), targetType, result); } }; diff --git a/mlir/include/mlir/Transforms/InliningUtils.h b/mlir/include/mlir/Transforms/InliningUtils.h --- a/mlir/include/mlir/Transforms/InliningUtils.h +++ b/mlir/include/mlir/Transforms/InliningUtils.h @@ -85,6 +85,22 @@ return false; } + /// Returns true if `sourceType`, either a call argument type or a callee + /// result type, can be converted to `targetType`, or false otherwise. This + /// hook is only called for arguments and results that require type conversion + /// for inlining. It executes right before the `isLegalToInline` method, which + /// checks if a callable can be inlined in-place of a call. The `call` is + /// registered to the current dialect and implements `CallOpInterface`, while + /// the `callable` implements `CallableOpInterface`. The `argOrResAttrs` + /// dictionary contains the argument or result attributes and is guaranteed to + /// be non-null even if no attributes are present. + virtual bool isTypeConvertible(Operation *call, Operation *callable, + Type sourceType, Type targetType, + DictionaryAttr argOrResAttrs, + bool isResult) const { + return false; + } + /// This hook is invoked on an operation that contains regions. It should /// return true if the analyzer should recurse within the regions of this /// operation when computing legality and cost, false otherwise. The default @@ -121,37 +137,15 @@ "must implement handleTerminator in the case of one inlined block"); } - /// Attempt to materialize a conversion for a type mismatch between a call - /// from this dialect, and a callable region. This method should generate an - /// operation that takes 'input' as the only operand, and produces a single - /// result of 'resultType'. If a conversion can not be generated, nullptr - /// should be returned. For example, this hook may be invoked in the following - /// scenarios: - /// func @foo(i32) -> i32 { ... } - /// - /// // Mismatched input operand - /// ... = foo.call @foo(%input : i16) -> i32 - /// - /// // Mismatched result type. - /// ... = foo.call @foo(%input : i32) -> i16 - /// - /// NOTE: This hook may be invoked before the 'isLegal' checks above. - virtual Operation *materializeCallConversion(OpBuilder &builder, Value input, - Type resultType, - Location conversionLoc) const { - return nullptr; - } - /// Hook to transform the call arguments before using them to replace the /// callee arguments. It returns the transformation result or `argument` /// itself if the hook did not change anything. The type of the returned value /// has to match `targetType`, and the `argumentAttrs` dictionary is non-null - /// even if no attribute is present. The hook is called after converting the - /// callsite argument types using the materializeCallConversion callback, and - /// right before inlining the callee region. Any operations created using the - /// provided `builder` are inserted right before the inlined callee region. - /// Example use cases are the insertion of copies for by value arguments, or - /// integer conversions that require signedness information. + /// even if no attributes are present. The hook is called right before + /// inlining the callee region. Any operations created using the provided + /// `builder` are inserted right before the inlined callee region. Example use + /// cases are the insertion of copies for by value arguments, or integer + /// conversions that require signedness information. virtual Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable, Value argument, Type targetType, @@ -163,12 +157,11 @@ /// results. It returns the transformation result or the `result` itself if /// the hook did not change anything. The type of the returned values has to /// match `targetType`, and the `resultAttrs` dictionary is non-null even if - /// no attribute is present. The hook is called right before handling - /// terminators, and obtains the callee result before converting its type - /// using the `materializeCallConversion` callback. Any operations created - /// using the provided `builder` are inserted right after the inlined callee - /// region. Example use cases are the insertion of copies for by value results - /// or integer conversions that require signedness information. + /// no attributes are present. The hook is called right before handling + /// terminators. Any operations created using the provided `builder` are + /// inserted right after the inlined callee region. Example use cases are the + /// insertion of copies for by value results or integer conversions that + /// require signedness information. /// NOTE: This hook is invoked after inlining the `callable` region. virtual Value handleResult(OpBuilder &builder, Operation *call, Operation *callable, Value result, Type targetType, @@ -209,6 +202,10 @@ IRMapping &valueMapping) const; virtual bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, IRMapping &valueMapping) const; + virtual bool isTypeConvertible(Operation *call, Operation *callable, + Type sourceType, Type targetType, + DictionaryAttr argOrResAttrs, + bool isResult) const; virtual bool shouldAnalyzeRecursively(Operation *op) const; //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp --- a/mlir/lib/Transforms/Utils/InliningUtils.cpp +++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp @@ -81,6 +81,16 @@ return false; } +bool InlinerInterface::isTypeConvertible(Operation *call, Operation *callable, + Type sourceType, Type targetType, + DictionaryAttr argOrResAttrs, + bool isResult) const { + if (auto *handler = getInterfaceFor(call)) + return handler->isTypeConvertible(call, callable, sourceType, targetType, + argOrResAttrs, isResult); + return false; +} + bool InlinerInterface::shouldAnalyzeRecursively(Operation *op) const { auto *handler = getInterfaceFor(op); return handler ? handler->shouldAnalyzeRecursively(op) : true; @@ -161,26 +171,49 @@ // Inline Methods //===----------------------------------------------------------------------===// -static void handleArgumentImpl(InlinerInterface &interface, OpBuilder &builder, - CallOpInterface call, - CallableOpInterface callable, - IRMapping &mapper) { - if (!call || !callable) - return; - - // Unpack the argument attributes if there are any. +/// Returns the array of argument attribute dictionaries. The attribute +/// dictionaries are non-null even if no attributes are present. +static SmallVector +getArgumentAttributes(CallableOpInterface callable) { SmallVector argAttrs( callable.getCallableRegion()->getNumArguments(), - builder.getDictionaryAttr({})); + DictionaryAttr::get(callable->getContext(), {})); if (ArrayAttr arrayAttr = callable.getCallableArgAttrs()) { assert(arrayAttr.size() == argAttrs.size()); for (auto [idx, attr] : llvm::enumerate(arrayAttr)) argAttrs[idx] = cast(attr); } + return argAttrs; +} + +/// Returns the array of result attribute dictionaries. The attribute +/// dictionaries are non-null even if no attributes are present. +static SmallVector +getResultAttributes(CallableOpInterface callable) { + SmallVector resAttrs( + callable.getCallableResults().size(), + DictionaryAttr::get(callable->getContext(), {})); + if (ArrayAttr arrayAttr = callable.getCallableResAttrs()) { + assert(arrayAttr.size() == resAttrs.size()); + for (auto [idx, attr] : llvm::enumerate(arrayAttr)) + resAttrs[idx] = cast(attr); + } + return resAttrs; +} + +static void handleArgumentImpl(InlinerInterface &interface, OpBuilder &builder, + CallOpInterface call, + CallableOpInterface callable, + IRMapping &mapper) { + if (!call || !callable) + return; + + // Unpack the argument attributes. + SmallVector argAttrs = getArgumentAttributes(callable); // Run the argument attribute handler for the given argument and attribute. - for (auto [blockArg, argAttr] : - llvm::zip(callable.getCallableRegion()->getArguments(), argAttrs)) { + for (auto [blockArg, argAttr] : llvm::zip_equal( + callable.getCallableRegion()->getArguments(), argAttrs)) { Value newArgument = interface.handleArgument(builder, call, callable, mapper.lookup(blockArg), blockArg.getType(), argAttr); @@ -198,31 +231,21 @@ if (!call || !callable) return; - // Unpack the result attributes if there are any. - SmallVector resAttrs(results.size(), - builder.getDictionaryAttr({})); - if (ArrayAttr arrayAttr = callable.getCallableResAttrs()) { - assert(arrayAttr.size() == resAttrs.size()); - for (auto [idx, attr] : llvm::enumerate(arrayAttr)) - resAttrs[idx] = cast(attr); - } + // Unpack the result attributes. + SmallVector resAttrs = getResultAttributes(callable); // Run the result attribute handler for the given result and attribute. SmallVector resultAttributes; - for (auto [result, resAttr] : llvm::zip(results, resAttrs)) { + for (auto [result, resAttr, callResult] : + llvm::zip_equal(results, resAttrs, call->getResults())) { // Store the original result users before running the handler. DenseSet resultUsers; for (Operation *user : result.getUsers()) resultUsers.insert(user); - // TODO: Use the type of the call result to replace once the hook can be - // used for type conversions. At the moment, all type conversions have to be - // done using materializeCallConversion. - Type targetType = result.getType(); - Value newResult = interface.handleResult(builder, call, callable, result, - targetType, resAttr); - assert(newResult.getType() == targetType && + callResult.getType(), resAttr); + assert(newResult.getType() == callResult.getType() && "expected the handled result type to match the target type"); // Replace the result uses except for the ones introduce by the handler. @@ -410,28 +433,6 @@ shouldCloneInlinedRegion); } -/// Utility function used to generate a cast operation from the given interface, -/// or return nullptr if a cast could not be generated. -static Value materializeConversion(const DialectInlinerInterface *interface, - SmallVectorImpl &castOps, - OpBuilder &castBuilder, Value arg, Type type, - Location conversionLoc) { - if (!interface) - return nullptr; - - // Check to see if the interface for the call can materialize a conversion. - Operation *castOp = interface->materializeCallConversion(castBuilder, arg, - type, conversionLoc); - if (!castOp) - return nullptr; - castOps.push_back(castOp); - - // Ensure that the generated cast is correct. - assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg && - castOp->getNumResults() == 1 && *castOp->result_type_begin() == type); - return castOp->getResult(0); -} - /// This function inlines a given region, 'src', of a callable operation, /// 'callable', into the location defined by the given call operation. This /// function returns failure if inlining is not possible, success otherwise. On @@ -456,69 +457,41 @@ callResults.size() != callableResultTypes.size()) return failure(); - // A set of cast operations generated to matchup the signature of the region - // with the signature of the call. - SmallVector castOps; - castOps.reserve(callOperands.size() + callResults.size()); - - // Functor used to cleanup generated state on failure. - auto cleanupState = [&] { - for (auto *op : castOps) { - op->getResult(0).replaceAllUsesWith(op->getOperand(0)); - op->erase(); - } - return failure(); - }; - - // Builder used for any conversion operations that need to be materialized. - OpBuilder castBuilder(call); - Location castLoc = call.getLoc(); - const auto *callInterface = interface.getInterfaceFor(call->getDialect()); - - // Map the provided call operands to the arguments of the region. - IRMapping mapper; - for (unsigned i = 0, e = callOperands.size(); i != e; ++i) { - BlockArgument regionArg = entryBlock->getArgument(i); - Value operand = callOperands[i]; - - // If the call operand doesn't match the expected region argument, try to - // generate a cast. - Type regionArgType = regionArg.getType(); - if (operand.getType() != regionArgType) { - if (!(operand = materializeConversion(callInterface, castOps, castBuilder, - operand, regionArgType, castLoc))) - return cleanupState(); - } - mapper.map(regionArg, operand); + // Check that argument types are convertible. + SmallVector argAttrs = getArgumentAttributes(callable); + for (auto [argumentType, targetType, argumentAttrs] : + llvm::zip_equal(call.getArgOperands().getTypes(), + entryBlock->getArgumentTypes(), argAttrs)) { + if (argumentType == targetType) + continue; + if (!interface.isTypeConvertible(call, callable, argumentType, targetType, + argumentAttrs, false)) + return failure(); } - // Ensure that the resultant values of the call match the callable. - castBuilder.setInsertionPointAfter(call); - for (unsigned i = 0, e = callResults.size(); i != e; ++i) { - Value callResult = callResults[i]; - if (callResult.getType() == callableResultTypes[i]) + // Check that result types are convertible. + SmallVector resAttrs = getResultAttributes(callable); + for (auto [resultType, targetType, resultAttrs] : + llvm::zip_equal(callableResultTypes, call->getResultTypes(), resAttrs)) { + if (resultType == targetType) continue; - - // Generate a conversion that will produce the original type, so that the IR - // is still valid after the original call gets replaced. - Value castResult = - materializeConversion(callInterface, castOps, castBuilder, callResult, - callResult.getType(), castLoc); - if (!castResult) - return cleanupState(); - callResult.replaceAllUsesWith(castResult); - castResult.getDefiningOp()->replaceUsesOfWith(castResult, callResult); + if (!interface.isTypeConvertible(call, callable, resultType, targetType, + resultAttrs, true)) + return failure(); } // Check that it is legal to inline the callable into the call. if (!interface.isLegalToInline(call, callable, shouldCloneInlinedRegion)) - return cleanupState(); + return failure(); + + IRMapping mapper; + for (auto [blockArg, operand] : + llvm::zip(entryBlock->getArguments(), callOperands)) + mapper.map(blockArg, operand); // Attempt to inline the call. - if (failed(inlineRegionImpl(interface, src, call->getBlock(), - ++call->getIterator(), mapper, callResults, - callableResultTypes, call.getLoc(), - shouldCloneInlinedRegion, call))) - return cleanupState(); - return success(); + return inlineRegionImpl(interface, src, call->getBlock(), + ++call->getIterator(), mapper, callResults, + callableResultTypes, call.getLoc(), + shouldCloneInlinedRegion, call); } diff --git a/mlir/test/Transforms/inlining.mlir b/mlir/test/Transforms/inlining.mlir --- a/mlir/test/Transforms/inlining.mlir +++ b/mlir/test/Transforms/inlining.mlir @@ -233,9 +233,6 @@ %1 = arith.subi %arg0, %arg1 : i16 return %0, %1 : i16, i16 } -func.func @handle_attr_callee_fn(%arg0 : i32 {"test.handle_argument"}) -> (i32 {"test.handle_result"}) { - return %arg0 : i32 -} // CHECK-LABEL: func @inline_handle_attr_call // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] @@ -250,16 +247,3 @@ %res0, %res1 = "test.conversion_call_op"(%arg0, %arg1) { callee=@handle_attr_callee_fn_multi_arg } : (i16, i16) -> (i16, i16) return %res0, %res1 : i16, i16 } - -// CHECK-LABEL: func @inline_convert_and_handle_attr_call -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -func.func @inline_convert_and_handle_attr_call(%arg0 : i16) -> (i16) { - - // CHECK: %[[CAST_INPUT:.*]] = "test.cast"(%[[ARG0]]) : (i16) -> i32 - // CHECK: %[[CHANGE_INPUT:.*]] = "test.type_changer"(%[[CAST_INPUT]]) : (i32) -> i32 - // CHECK: %[[CHANGE_RESULT:.*]] = "test.type_changer"(%[[CHANGE_INPUT]]) : (i32) -> i32 - // CHECK: %[[CAST_RESULT:.*]] = "test.cast"(%[[CHANGE_RESULT]]) : (i32) -> i16 - // CHECK: return %[[CAST_RESULT]] - %res = "test.conversion_call_op"(%arg0) { callee=@handle_attr_callee_fn } : (i16) -> (i16) - return %res : i16 -} diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -304,6 +304,14 @@ // Don't allow inlining calls that are marked `noinline`. return !call->hasAttr("noinline"); } + bool isTypeConvertible(Operation *call, Operation *callable, Type sourceType, + Type targetType, DictionaryAttr argOrResAttrs, + bool isResult) const final { + return (sourceType.isSignlessInteger(16) || + sourceType.isSignlessInteger(32)) && + (targetType.isSignlessInteger(16) || + targetType.isSignlessInteger(32)); + } bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final { // Inlining into test dialect regions is legal. return true; @@ -337,39 +345,31 @@ valuesToRepl[it.index()].replaceAllUsesWith(it.value()); } - /// Attempt to materialize a conversion for a type mismatch between a call - /// from this dialect, and a callable region. This method should generate an - /// operation that takes 'input' as the only operand, and produces a single - /// result of 'resultType'. If a conversion can not be generated, nullptr - /// should be returned. - Operation *materializeCallConversion(OpBuilder &builder, Value input, - Type resultType, - Location conversionLoc) const final { - // Only allow conversion for i16/i32 types. - if (!(resultType.isSignlessInteger(16) || - resultType.isSignlessInteger(32)) || - !(input.getType().isSignlessInteger(16) || - input.getType().isSignlessInteger(32))) - return nullptr; - return builder.create(conversionLoc, resultType, input); - } - Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable, Value argument, Type targetType, DictionaryAttr argumentAttrs) const final { - if (!argumentAttrs.contains("test.handle_argument")) + // Cast if an attribute is present. + if (argumentAttrs.contains("test.handle_argument")) + return builder.create(call->getLoc(), targetType, + argument); + + // Cast only if the types do not match. + if (argument.getType() == targetType) return argument; - return builder.create(call->getLoc(), targetType, - argument); + return builder.create(call->getLoc(), targetType, argument); } Value handleResult(OpBuilder &builder, Operation *call, Operation *callable, Value result, Type targetType, DictionaryAttr resultAttrs) const final { - if (!resultAttrs.contains("test.handle_result")) + // Cast if an attribute is present. + if (resultAttrs.contains("test.handle_result")) + return builder.create(call->getLoc(), targetType, + result); + // Cast only if the types do not match. + if (result.getType() == targetType) return result; - return builder.create(call->getLoc(), targetType, - result); + return builder.create(call->getLoc(), targetType, result); } void processInlinedCallBlocks(