diff --git a/flang/lib/Optimizer/Dialect/FIRDialect.cpp b/flang/lib/Optimizer/Dialect/FIRDialect.cpp
--- a/flang/lib/Optimizer/Dialect/FIRDialect.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRDialect.cpp
@@ -14,9 +14,55 @@
 #include "flang/Optimizer/Dialect/FIRAttr.h"
 #include "flang/Optimizer/Dialect/FIROps.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
+#include "mlir/Transforms/InliningUtils.h"
 
 using namespace fir;
 
+namespace {
+/// This class defines the interface for handling inlining of FIR calls.
+struct FIRInlinerInterface : public mlir::DialectInlinerInterface {
+  using DialectInlinerInterface::DialectInlinerInterface;
+
+  /// This hook checks to see if the callable operation `callable` is legal
+  /// to inline into the call operation `call`.
+  bool isLegalToInline(mlir::Operation *call, mlir::Operation *callable,
+                       bool wouldBeCloned) const final {
+    return fir::canLegallyInline(call, callable, wouldBeCloned);
+  }
+
+  /// This hook checks to see if the operation `op` is legal to inline into the
+  /// given region `reg`.
+  bool isLegalToInline(mlir::Operation *op, mlir::Region *reg,
+                       bool wouldBeCloned,
+                       mlir::BlockAndValueMapping &map) const final {
+    return fir::canLegallyInline(op, reg, wouldBeCloned, map);
+  }
+
+  /// This hook is called when a terminator operation has been inlined.
+  /// We handle the return (a Fortran FUNCTION) by replacing the values
+  /// previously returned by the call operation with the operands of the
+  /// return.
+  void handleTerminator(mlir::Operation *op,
+                        llvm::ArrayRef<mlir::Value> valuesToRepl) const final {
+    auto returnOp = mlir::cast<mlir::ReturnOp>(op);
+    assert(returnOp.getNumOperands() == valuesToRepl.size() &&
+           "return must provide one value per replaced call result");
+    for (const auto &it : llvm::enumerate(returnOp.getOperands()))
+      valuesToRepl[it.index()].replaceAllUsesWith(it.value());
+  }
+
+  /// This hook materializes a conversion for a type mismatch between the
+  /// call and the inlined callable: `input` is converted to `resultType`
+  /// with a fir.convert operation.
+  mlir::Operation *materializeCallConversion(mlir::OpBuilder &builder,
+                                             mlir::Value input,
+                                             mlir::Type resultType,
+                                             mlir::Location loc) const final {
+    return builder.create<fir::ConvertOp>(loc, resultType, input);
+  }
+};
+} // namespace
+
 fir::FIROpsDialect::FIROpsDialect(mlir::MLIRContext *ctx)
     : mlir::Dialect("fir", ctx, mlir::TypeID::get<FIROpsDialect>()) {
   registerTypes();
@@ -25,6 +71,7 @@
 #define GET_OP_LIST
 #include "flang/Optimizer/Dialect/FIROps.cpp.inc"
       >();
+  addInterfaces<FIRInlinerInterface>();
 }
 
 // anchor the class vtable to this compilation unit
diff --git a/flang/test/Fir/inline.fir b/flang/test/Fir/inline.fir
new file mode 100644
--- /dev/null
+++ b/flang/test/Fir/inline.fir
@@ -0,0 +1,19 @@
+// RUN: tco --target=x86_64-unknown-linux-gnu --inline-all %s -o - | FileCheck %s
+
+// CHECK-LABEL: @add
+func @add(%a : i32, %b : i32) -> i32 {
+  // CHECK: %[[add:.*]] = add i32
+  %p = arith.addi %a, %b : i32
+  // CHECK: ret i32 %[[add]]
+  return %p : i32
+}
+
+// CHECK-LABEL: @test
+func @test(%a : i32, %b : i32, %c : i32) -> i32 {
+  // CHECK: %[[add:.*]] = add i32
+  %m = fir.call @add(%a, %b) : (i32, i32) -> i32
+  // CHECK: %[[mul:.*]] = mul i32 %[[add]],
+  %n = arith.muli %m, %c : i32
+  // CHECK: ret i32 %[[mul]]
+  return %n : i32
+}