diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2327,22 +2327,30 @@
 
   let description = [{
     Perform a dynamic dispatch on the method name via the dispatch table
-    associated with the first argument. The attribute 'pass_arg_pos' can be
-    used to select a dispatch argument other than the first one.
+    associated with the first operand (the object). The optional attribute
+    `pass_arg_pos` is the zero-based index, among the operands that follow the
+    object, of the passed-object argument. Its absence means nopass.
 
     ```mlir
-      %r = fir.dispatch methodA(%o) : (!fir.box) -> i32
+      // fir.dispatch with no attribute.
+      %r = fir.dispatch "methodA"(%o) : (!fir.class) -> i32
+
+      // pass_arg_pos = 0 selects the first operand after the object.
+      %r = fir.dispatch "methodA"(%o, %o) : (!fir.class, !fir.class) -> i32 {pass_arg_pos = 0 : i32}
     ```
   }];
 
   let arguments = (ins
     StrAttr:$method,
-    fir_BoxType:$object,
-    Variadic:$args
+    fir_ClassType:$object,
+    Variadic:$args,
+    OptionalAttr:$pass_arg_pos
   );
 
   let results = (outs Variadic);
 
+  let hasVerifier = 1;
+
   let hasCustomAssemblyFormat = 1;
 
   let extraClassDeclaration = [{
@@ -2350,14 +2358,10 @@
     operand_range getArgOperands() {
       return {arg_operand_begin(), arg_operand_end()};
     }
-    // operand[0] is the object (of box type)
+    // operand[0] is the object (of class type)
     operand_iterator arg_operand_begin() { return operand_begin() + 1; }
     operand_iterator arg_operand_end() { return operand_end(); }
-    static constexpr llvm::StringRef getPassArgAttrName() {
-      return "pass_arg_pos";
-    }
     static constexpr llvm::StringRef getMethodAttrNameStr() { return "method"; }
-    unsigned passArgPos();
   }];
 }
 
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -1038,6 +1038,20 @@
 // DispatchOp
 //===----------------------------------------------------------------------===//
 
+mlir::LogicalResult fir::DispatchOp::verify() {
+  // pass_arg_pos indexes getArgOperands() (the operands after the object).
+  // Compare with size() directly: size() - 1 would wrap when there are none.
+  if (getPassArgPos() && *getPassArgPos() >= getArgOperands().size())
+    return emitOpError(
+        "pass_arg_pos must be smaller than the number of operands");
+
+  // Operand pointed by pass_arg_pos must have polymorphic type.
+  if (getPassArgPos() &&
+      !fir::isPolymorphicType(getArgOperands()[*getPassArgPos()].getType()))
+    return emitOpError("pass_arg_pos must be a polymorphic operand");
+  return mlir::success();
+}
+
 mlir::FunctionType fir::DispatchOp::getFunctionType() {
   return mlir::FunctionType::get(getContext(), getOperandTypes(),
                                  getResultTypes());
@@ -1060,11 +1074,11 @@
                         parser.getBuilder().getStringAttr(calleeName));
   }
   if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren) ||
-      parser.parseOptionalAttrDict(result.attributes) ||
       parser.parseColonType(calleeType) ||
       parser.addTypesToList(calleeType.getResults(), result.types) ||
       parser.resolveOperands(operands, calleeType.getInputs(), calleeLoc,
-                             result.operands))
+                             result.operands) ||
+      parser.parseOptionalAttrDict(result.attributes))
     return mlir::failure();
   return mlir::success();
 }
@@ -1079,6 +1093,9 @@
   p << ") : ";
   p.printFunctionalType(getOperation()->getOperandTypes(),
                         getOperation()->getResultTypes());
+  p.printOptionalAttrDict(getOperation()->getAttrs(),
+                          {mlir::SymbolTable::getSymbolAttrName(),
+                           fir::DispatchOp::getMethodAttrNameStr()});
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/flang/test/Fir/Todo/dispatch.fir b/flang/test/Fir/Todo/dispatch.fir
--- a/flang/test/Fir/Todo/dispatch.fir
+++ b/flang/test/Fir/Todo/dispatch.fir
@@ -3,8 +3,8 @@
 // Test `fir.dispatch` conversion to llvm.
 // Not implemented yet.
 
-func.func @dispatch(%arg0: !fir.box>) {
-// CHECK: not yet implemented: fir.dispatch codegen
-  %0 = fir.dispatch "method"(%arg0) : (!fir.box>) -> i32
+func.func @dispatch(%arg0: !fir.class>) {
+// CHECK: not yet implemented: fir.class type conversion
+  %0 = fir.dispatch "method"(%arg0) : (!fir.class>) -> i32
   return
 }
diff --git a/flang/test/Fir/fir-ops.fir b/flang/test/Fir/fir-ops.fir
--- a/flang/test/Fir/fir-ops.fir
+++ b/flang/test/Fir/fir-ops.fir
@@ -114,14 +114,14 @@
   %25 = fir.insert_value %22, %cf1, ["f", !fir.type] : (!fir.type, f32) -> !fir.type
   %26 = fir.len_param_index f, !fir.type
 
-// CHECK: [[VAL_31:%.*]] = fir.call @box3() : () -> !fir.box>
-// CHECK: [[VAL_32:%.*]] = fir.dispatch "method"([[VAL_31]]) : (!fir.box>) -> i32
+// CHECK: [[VAL_31:%.*]] = fir.call @box3() : () -> !fir.class>
+// CHECK: [[VAL_32:%.*]] = fir.dispatch "method"([[VAL_31]]) : (!fir.class>) -> i32
 // CHECK: [[VAL_33:%.*]] = fir.convert [[VAL_32]] : (i32) -> i64
 // CHECK: [[VAL_34:%.*]] = fir.gentypedesc !fir.type
 // CHECK: fir.call @user_tdesc([[VAL_34]]) : (!fir.tdesc>) -> ()
 // CHECK: [[VAL_35:%.*]] = fir.no_reassoc [[VAL_33]] : i64
-  %27 = fir.call @box3() : () -> !fir.box>
-  %28 = fir.dispatch "method"(%27) : (!fir.box>) -> i32
+  %27 = fir.call @box3() : () -> !fir.class>
+  %28 = fir.dispatch "method"(%27) : (!fir.class>) -> i32
   %29 = fir.convert %28 : (i32) -> i64
   %30 = fir.gentypedesc !fir.type
   fir.call @user_tdesc(%30) : (!fir.tdesc>) -> ()
@@ -309,12 +309,12 @@
 
 // CHECK: ^bb5:
 // CHECK: [[VAL_99:%.*]] = arith.constant 0 : i32
-// CHECK: [[VAL_100:%.*]] = fir.call @get_method_box() : () -> !fir.box>
-// CHECK: fir.dispatch "method"([[VAL_100]]) : (!fir.box>) -> ()
+// CHECK: [[VAL_100:%.*]] = fir.call @get_method_box() : () -> !fir.class>
+// CHECK: fir.dispatch "method"([[VAL_100]]) : (!fir.class>) -> ()
 ^bb5 :
   %zero = arith.constant 0 : i32
-  %7 = fir.call @get_method_box() : () -> !fir.box>
-  fir.dispatch method(%7) : (!fir.box>) -> ()
+  %7 = fir.call @get_method_box() : () -> !fir.class>
+  fir.dispatch method(%7) : (!fir.class>) -> ()
 
 // CHECK: return [[VAL_99]] : i32
 // CHECK: }
@@ -805,3 +805,17 @@
 // CHECK: %{{.*}} = fir.array_amend %{{.*}}, %{{.*}} : (!fir.array, !fir.ref) -> !fir.array
   return
 }
+
+func.func private @dispatch(%arg0: !fir.class>, %arg1: i32) -> () {
+  // CHECK-LABEL: func.func private @dispatch(
+  // CHECK-SAME: %[[CLASS:.*]]: !fir.class>, %[[INTARG:.*]]: i32)
+  fir.dispatch "proc1"(%arg0, %arg0) : (!fir.class>, !fir.class>) -> () {pass_arg_pos = 0 : i32}
+  // CHECK: fir.dispatch "proc1"(%[[CLASS]], %[[CLASS]]) : (!fir.class>, !fir.class>) -> () {pass_arg_pos = 0 : i32}
+
+  fir.dispatch "proc2"(%arg0) : (!fir.class>) -> () {nopass}
+  // CHECK: fir.dispatch "proc2"(%[[CLASS]]) : (!fir.class>) -> () {nopass}
+
+  fir.dispatch "proc3"(%arg0, %arg1, %arg0) : (!fir.class>, i32, !fir.class>) -> () {pass_arg_pos = 1 : i32}
+  // CHECK: fir.dispatch "proc3"(%[[CLASS]], %[[INTARG]], %[[CLASS]]) : (!fir.class>, i32, !fir.class>) -> () {pass_arg_pos = 1 : i32}
+  return
+}
diff --git a/flang/test/Fir/invalid.fir b/flang/test/Fir/invalid.fir
--- a/flang/test/Fir/invalid.fir
+++ b/flang/test/Fir/invalid.fir
@@ -756,3 +756,27 @@
   return
 }
 func.func private @ifoo(!fir.ref) -> i32
+
+// -----
+
+func.func private @dispatch(%arg0: !fir.class>) -> () {
+  // expected-error@+1 {{'fir.dispatch' op pass_arg_pos must be smaller than the number of operands}}
+  fir.dispatch "proc1"(%arg0, %arg0) : (!fir.class>, !fir.class>) -> () {pass_arg_pos = 1 : i32}
+  return
+}
+
+// -----
+
+func.func private @dispatch(%arg0: !fir.class>) -> () {
+  // expected-error@+1 {{'fir.dispatch' op pass_arg_pos must be smaller than the number of operands}}
+  fir.dispatch "proc1"(%arg0) : (!fir.class>) -> () {pass_arg_pos = 0 : i32}
+  return
+}
+
+// -----
+
+func.func private @dispatch(%arg0: !fir.class>, %arg1: i32) -> () {
+  // expected-error@+1 {{'fir.dispatch' op pass_arg_pos must be a polymorphic operand}}
+  fir.dispatch "proc1"(%arg0, %arg0, %arg1) : (!fir.class>, !fir.class>, i32) -> () {pass_arg_pos = 1 : i32}
+  return
+}