diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp --- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp +++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp @@ -209,6 +209,8 @@ newOpers.push_back(callOp.getOperand(0)); dropFront = 1; } + } else { + dropFront = 1; // First operand is the polymorphic object. } // Determine the rewrite function, `wrap`, for the result value. @@ -231,6 +233,7 @@ llvm::SmallVector trailingInTys; llvm::SmallVector trailingOpers; + unsigned passArgShift = 0; for (auto e : llvm::enumerate( llvm::zip(fnTy.getInputs().drop_front(dropFront), callOp.getOperands().drop_front(dropFront)))) { @@ -314,6 +317,10 @@ } }) .Default([&](mlir::Type ty) { + if constexpr (std::is_same_v, fir::DispatchOp>) { + if (callOp.getPassArgPos() && *callOp.getPassArgPos() == index) + passArgShift = newOpers.size() - *callOp.getPassArgPos(); + } newInTys.push_back(ty); newOpers.push_back(oper); }); @@ -338,8 +345,14 @@ else replaceOp(callOp, newCall.getResults()); } else { - // A is fir::DispatchOp - TODO(loc, "dispatch not implemented"); + fir::DispatchOp dispatchOp = rewriter->create( + loc, newResTys, rewriter->getStringAttr(callOp.getMethod()), + callOp.getOperands()[0], newOpers, + rewriter->getI32IntegerAttr(*callOp.getPassArgPos() + passArgShift)); + if (wrap) + replaceOp(callOp, (*wrap)(dispatchOp.getOperation())); + else + replaceOp(callOp, dispatchOp.getResults()); } } diff --git a/flang/test/Fir/target-rewrite-complex.fir b/flang/test/Fir/target-rewrite-complex.fir --- a/flang/test/Fir/target-rewrite-complex.fir +++ b/flang/test/Fir/target-rewrite-complex.fir @@ -122,12 +122,12 @@ func.func private @paramcomplex4(!fir.complex<4>) -> () // Test that we rewrite calls to functions that return or accept complex<4>. -// I32-LABEL: func @callcomplex4() -// X64-LABEL: func @callcomplex4() -// AARCH64-LABEL: func @callcomplex4() -// PPC-LABEL: func @callcomplex4() -// SPARCV9-LABEL: func @callcomplex4() -func.func @callcomplex4() { +// I32-LABEL: func @callcomplex4 +// X64-LABEL: func @callcomplex4 +// AARCH64-LABEL: func @callcomplex4 +// PPC-LABEL: func @callcomplex4 +// SPARCV9-LABEL: func @callcomplex4 +func.func @callcomplex4(%arg0 : !fir.class>) { // I32: [[RES:%[0-9A-Za-z]+]] = fir.call @returncomplex4() : () -> i64 // X64: [[RES:%[0-9A-Za-z]+]] = fir.call @returncomplex4() : () -> !fir.vector<2:!fir.real<4>> @@ -181,6 +181,69 @@ // SPARCV9: [[B:%[0-9A-Za-z]+]] = fir.extract_value [[V]], [1 : i32] : (!fir.complex<4>) -> !fir.real<4> // SPARCV9: fir.call @paramcomplex4([[A]], [[B]]) : (!fir.real<4>, !fir.real<4>) -> () fir.call @paramcomplex4(%1) : (!fir.complex<4>) -> () + + // I32: [[RES:%[0-9A-Za-z]+]] = fir.dispatch "ret_complex"(%{{.*}} : !fir.class>) (%{{.*}} : !fir.class>) -> i64 {pass_arg_pos = 0 : i32} + // X64: [[RES:%[0-9A-Za-z]+]] = fir.dispatch "ret_complex"(%{{.*}} : !fir.class>) (%{{.*}} : !fir.class>) -> !fir.vector<2:!fir.real<4>> {pass_arg_pos = 0 : i32} + // AARCH64: [[RES:%[0-9A-Za-z]+]] = fir.dispatch "ret_complex"(%{{.*}} : !fir.class>) (%{{.*}} : !fir.class>) -> tuple, !fir.real<4>> {pass_arg_pos = 0 : i32} + // PPC: [[RES:%[0-9A-Za-z]+]] = fir.dispatch "ret_complex"(%{{.*}} : !fir.class>) (%{{.*}} : !fir.class>) -> tuple, !fir.real<4>> {pass_arg_pos = 0 : i32} + // SPARCV9: [[RES:%[0-9A-Za-z]+]] = fir.dispatch "ret_complex"(%{{.*}} : !fir.class>) (%{{.*}} : !fir.class>) -> tuple, !fir.real<4>> {pass_arg_pos = 0 : i32} + %2 = fir.dispatch "ret_complex"(%arg0 : !fir.class>) (%arg0 : !fir.class>) -> !fir.complex<4> {pass_arg_pos = 0 : i32} + + // I32: [[ADDRI64:%[0-9A-Za-z]+]] = fir.alloca i64 + // I32: fir.store [[RES]] to [[ADDRI64]] : !fir.ref + // I32: [[ADDRC:%[0-9A-Za-z]+]] = fir.convert [[ADDRI64]] : (!fir.ref) -> !fir.ref> + // I32: [[C:%[0-9A-Za-z]+]] = fir.load [[ADDRC]] : !fir.ref> + // I32: [[ADDRC2:%[0-9A-Za-z]+]] = fir.alloca !fir.complex<4> + // I32: fir.store [[C]] to [[ADDRC2]] : !fir.ref> + // I32: [[T:%[0-9A-Za-z]+]] = fir.convert [[ADDRC2]] : (!fir.ref>) -> !fir.ref, !fir.real<4>>> + // I32: fir.dispatch "with_complex"(%{{.*}} : !fir.class>) (%{{.*}}, [[T]] : !fir.class>, !fir.ref, !fir.real<4>>>) {pass_arg_pos = 0 : i32} + + // X64: [[ADDRV:%[0-9A-Za-z]+]] = fir.alloca !fir.vector<2:!fir.real<4>> + // X64: fir.store [[RES]] to [[ADDRV]] : !fir.ref>> + // X64: [[ADDRC:%[0-9A-Za-z]+]] = fir.convert [[ADDRV]] : (!fir.ref>>) -> !fir.ref> + // X64: [[V:%[0-9A-Za-z]+]] = fir.load [[ADDRC]] : !fir.ref> + // X64: [[ADDRV2:%[0-9A-Za-z]+]] = fir.alloca !fir.vector<2:!fir.real<4>> + // X64: [[ADDRC2:%[0-9A-Za-z]+]] = fir.convert [[ADDRV2]] : (!fir.ref>>) -> !fir.ref> + // X64: fir.store [[V]] to [[ADDRC2]] : !fir.ref> + // X64: [[VRELOADED:%[0-9A-Za-z]+]] = fir.load [[ADDRV2]] : !fir.ref>> + // X64: fir.dispatch "with_complex"(%{{.*}} : !fir.class>) (%{{.*}}, [[VRELOADED]] : !fir.class>, !fir.vector<2:!fir.real<4>>) {pass_arg_pos = 0 : i32} + + // AARCH64: [[ADDRT:%[0-9A-Za-z]+]] = fir.alloca tuple, !fir.real<4>> + // AARCH64: fir.store [[RES]] to [[ADDRT]] : !fir.ref, !fir.real<4>>> + // AARCH64: [[ADDRC:%[0-9A-Za-z]+]] = fir.convert [[ADDRT]] : (!fir.ref, !fir.real<4>>>) -> !fir.ref> + // AARCH64: [[V:%[0-9A-Za-z]+]] = fir.load [[ADDRC]] : !fir.ref> + // AARCH64: [[ADDRARR:%[0-9A-Za-z]+]] = fir.alloca !fir.array<2x!fir.real<4>> + // AARCH64: [[ADDRC2:%[0-9A-Za-z]+]] = fir.convert [[ADDRARR]] : (!fir.ref>>) -> !fir.ref> + // AARCH64: fir.store [[V]] to [[ADDRC2]] : !fir.ref> + // AARCH64: [[ARR:%[0-9A-Za-z]+]] = fir.load [[ADDRARR]] : !fir.ref>> + // AARCH64: fir.dispatch "with_complex"(%{{.*}} : !fir.class>) (%{{.*}}, [[ARR]] : !fir.class>, !fir.array<2x!fir.real<4>>) {pass_arg_pos = 0 : i32} + + // PPC: [[ADDRT:%[0-9A-Za-z]+]] = fir.alloca tuple, !fir.real<4>> + // PPC: fir.store [[RES]] to [[ADDRT]] : !fir.ref, !fir.real<4>>> + // PPC: [[ADDRC:%[0-9A-Za-z]+]] = fir.convert [[ADDRT]] : (!fir.ref, !fir.real<4>>>) -> !fir.ref> + // PPC: [[V:%[0-9A-Za-z]+]] = fir.load [[ADDRC]] : !fir.ref> + // PPC: [[A:%[0-9A-Za-z]+]] = fir.extract_value [[V]], [0 : i32] : (!fir.complex<4>) -> !fir.real<4> + // PPC: [[B:%[0-9A-Za-z]+]] = fir.extract_value [[V]], [1 : i32] : (!fir.complex<4>) -> !fir.real<4> + // PPC: fir.dispatch "with_complex"(%{{.*}} : !fir.class>) (%{{.*}}, [[A]], [[B]] : !fir.class>, !fir.real<4>, !fir.real<4>) {pass_arg_pos = 0 : i32} + + // SPARCV9: [[ADDRT:%[0-9A-Za-z]+]] = fir.alloca tuple, !fir.real<4>> + // SPARCV9: fir.store [[RES]] to [[ADDRT]] : !fir.ref, !fir.real<4>>> + // SPARCV9: [[ADDRC:%[0-9A-Za-z]+]] = fir.convert [[ADDRT]] : (!fir.ref, !fir.real<4>>>) -> !fir.ref> + // SPARCV9: [[V:%[0-9A-Za-z]+]] = fir.load [[ADDRC]] : !fir.ref> + // SPARCV9: [[A:%[0-9A-Za-z]+]] = fir.extract_value [[V]], [0 : i32] : (!fir.complex<4>) -> !fir.real<4> + // SPARCV9: [[B:%[0-9A-Za-z]+]] = fir.extract_value [[V]], [1 : i32] : (!fir.complex<4>) -> !fir.real<4> + // SPARCV9: fir.dispatch "with_complex"(%{{.*}} : !fir.class>) (%{{.*}}, [[A]], [[B]] : !fir.class>, !fir.real<4>, !fir.real<4>) {pass_arg_pos = 0 : i32} + + fir.dispatch "with_complex"(%arg0 : !fir.class>) (%arg0, %2 : !fir.class>, !fir.complex<4>) {pass_arg_pos = 0 : i32} + + + // I32: fir.dispatch "with_complex2"(%{{.*}} : !fir.class>) (%{{.*}}, %{{.*}} : !fir.ref, !fir.real<4>>>, !fir.class>) {pass_arg_pos = 1 : i32} + // X64: fir.dispatch "with_complex2"(%{{.*}} : !fir.class>) (%{{.*}}, %{{.*}} : !fir.vector<2:!fir.real<4>>, !fir.class>) {pass_arg_pos = 1 : i32} + // AARCH64: fir.dispatch "with_complex2"(%{{.*}} : !fir.class>) (%{{.*}}, %{{.*}} : !fir.array<2x!fir.real<4>>, !fir.class>) {pass_arg_pos = 1 : i32} + // PPC: fir.dispatch "with_complex2"(%{{.*}} : !fir.class>) (%{{.*}}, %{{.*}}, %{{.*}} : !fir.real<4>, !fir.real<4>, !fir.class>) {pass_arg_pos = 2 : i32} + // SPARCV9: fir.dispatch "with_complex2"(%{{.*}} : !fir.class>) (%{{.*}}, %{{.*}}, %{{.*}} : !fir.real<4>, !fir.real<4>, !fir.class>) {pass_arg_pos = 2 : i32} + fir.dispatch "with_complex2"(%arg0 : !fir.class>) (%2, %arg0 : !fir.complex<4>, !fir.class>) {pass_arg_pos = 1 : i32} + return }