[flang] Rename fir.select_type case attributes and restrict the selector type

Renames the attribute mnemonics used by fir.select_type cases:
  - ExactTypeAttr::getAttrName(): "instance" -> "type_is"
  - SubclassAttr::getAttrName(): "subsumed" -> "class_is"
The FIROps.td example for fir.select_type is updated to the new names.

Tightens fir::SelectTypeOp::verify():
  - the "must be a boxed type" diagnostic becomes
    "must be a fir.class or fir.box type";
  - a new check, applied when the selector dyn_casts to a box type,
    emits "selector must be polymorphic" when the box element type
    fails an isa<> check.

Tests: convert-to-llvm-invalid.fir and fir-ops.fir now pass a
!fir.class selector and use the new attribute names; invalid.fir gains
@invalid_selector, which expects "'fir.select_type' op selector must be
polymorphic" for a !fir.box selector.

NOTE(review): this copy of the patch is damaged and will not apply or
compile as-is. The original newlines are collapsed onto a few lines, and
text inside angle brackets has been stripped. That includes C++ template
arguments such as isa<...>(), dyn_cast<...>() and getAttrOfType<...>(),
and MLIR types such as !fir.box<...> and #fir.type_is<...>.
Presumably the new verifier check requires a fir.box selector to have a
`none` element type; confirm this against the original commit. Regenerate
the patch from that commit rather than hand-repairing this text.

diff --git a/flang/include/flang/Optimizer/Dialect/FIRAttr.h b/flang/include/flang/Optimizer/Dialect/FIRAttr.h --- a/flang/include/flang/Optimizer/Dialect/FIRAttr.h +++ b/flang/include/flang/Optimizer/Dialect/FIRAttr.h @@ -38,7 +38,7 @@ using Base::Base; using ValueType = mlir::Type; - static constexpr llvm::StringRef getAttrName() { return "instance"; } + static constexpr llvm::StringRef getAttrName() { return "type_is"; } static ExactTypeAttr get(mlir::Type value); mlir::Type getType() const; @@ -51,7 +51,7 @@ using Base::Base; using ValueType = mlir::Type; - static constexpr llvm::StringRef getAttrName() { return "subsumed"; } + static constexpr llvm::StringRef getAttrName() { return "class_is"; } static SubclassAttr get(mlir::Type value); mlir::Type getType() const; diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -651,10 +651,10 @@ ```mlir fir.select_type %arg : !fir.box<()> [ - #fir.instance>, ^bb1(%0 : i32), - #fir.instance>, ^bb2(%2 : i32), - #fir.subsumed>, ^bb3(%2 : i32), - #fir.instance>, ^bb4(%1,%3 : i32,f32), + #fir.type_is>, ^bb1(%0 : i32), + #fir.type_is>, ^bb2(%2 : i32), + #fir.class_is>, ^bb3(%2 : i32), + #fir.type_is>, ^bb4(%1,%3 : i32,f32), unit, ^bb5] ``` }]; diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -2970,8 +2970,11 @@ } mlir::LogicalResult fir::SelectTypeOp::verify() { - if (!(getSelector().getType().isa())) - return emitOpError("must be a boxed type"); + if (!(getSelector().getType().isa())) + return emitOpError("must be a fir.class or fir.box type"); + if (auto boxType = getSelector().getType().dyn_cast()) + if (!boxType.getEleTy().isa()) + return emitOpError("selector must be polymorphic"); auto cases = 
getOperation()->getAttrOfType(getCasesAttr()).getValue(); auto count = getNumDest(); diff --git a/flang/test/Fir/convert-to-llvm-invalid.fir b/flang/test/Fir/convert-to-llvm-invalid.fir --- a/flang/test/Fir/convert-to-llvm-invalid.fir +++ b/flang/test/Fir/convert-to-llvm-invalid.fir @@ -71,14 +71,14 @@ // Test `fir.select_type` conversion to llvm. // Should have been converted. -func.func @bar_select_type(%arg : !fir.box>) -> i32 { +func.func @bar_select_type(%arg : !fir.class>) -> i32 { %0 = arith.constant 1 : i32 %2 = arith.constant 3 : i32 // expected-error@+2{{fir.select_type should have already been converted}} // expected-error@+1{{failed to legalize operation 'fir.select_type'}} - fir.select_type %arg : !fir.box> [ - #fir.instance>,^bb1(%0:i32), - #fir.instance>,^bb2(%2:i32), + fir.select_type %arg : !fir.class> [ + #fir.type_is>,^bb1(%0:i32), + #fir.type_is>,^bb2(%2:i32), unit,^bb5 ] ^bb1(%a : i32) : return %a : i32 diff --git a/flang/test/Fir/fir-ops.fir b/flang/test/Fir/fir-ops.fir --- a/flang/test/Fir/fir-ops.fir +++ b/flang/test/Fir/fir-ops.fir @@ -322,8 +322,8 @@ } // CHECK-LABEL: func @bar_select_type( -// CHECK-SAME: [[VAL_101:%.*]]: !fir.box}>>) -> i32 { -func.func @bar_select_type(%arg : !fir.box}>>) -> i32 { +// CHECK-SAME: [[VAL_101:%.*]]: !fir.class}>>) -> i32 { +func.func @bar_select_type(%arg : !fir.class}>>) -> i32 { // CHECK: [[VAL_102:%.*]] = arith.constant 1 : i32 // CHECK: [[VAL_103:%.*]] = arith.constant 2 : i32 @@ -334,8 +334,8 @@ %2 = arith.constant 3 : i32 %3 = arith.constant 4 : i32 -// CHECK: fir.select_type [[VAL_101]] : !fir.box}>> [#fir.instance>, ^bb1([[VAL_102]] : i32), #fir.instance>, ^bb2([[VAL_104]] : i32), #fir.subsumed>, ^bb3([[VAL_104]] : i32), #fir.instance>, ^bb4([[VAL_103]] : i32), unit, ^bb5] - fir.select_type %arg : !fir.box}>> [ #fir.instance>,^bb1(%0:i32), #fir.instance>,^bb2(%2:i32), #fir.subsumed>,^bb3(%2:i32), #fir.instance>,^bb4(%1:i32), unit,^bb5 ] +// CHECK: fir.select_type [[VAL_101]] : !fir.class}>> 
[#fir.type_is>, ^bb1([[VAL_102]] : i32), #fir.type_is>, ^bb2([[VAL_104]] : i32), #fir.class_is>, ^bb3([[VAL_104]] : i32), #fir.type_is>, ^bb4([[VAL_103]] : i32), unit, ^bb5] + fir.select_type %arg : !fir.class}>> [ #fir.type_is>,^bb1(%0:i32), #fir.type_is>,^bb2(%2:i32), #fir.class_is>,^bb3(%2:i32), #fir.type_is>,^bb4(%1:i32), unit,^bb5 ] // CHECK: ^bb1([[VAL_106:%.*]]: i32): // CHECK: return [[VAL_106]] : i32 diff --git a/flang/test/Fir/invalid.fir b/flang/test/Fir/invalid.fir --- a/flang/test/Fir/invalid.fir +++ b/flang/test/Fir/invalid.fir @@ -928,3 +928,22 @@ %0 = fir.declare %arg0(%shape) {uniq_name = "x"} : (!fir.ref>>>, !fir.shift<2>) -> !fir.ref>>> return } + +// ----- + +func.func @invalid_selector(%arg : !fir.box>) -> i32 { + %0 = arith.constant 1 : i32 + %2 = arith.constant 3 : i32 + // expected-error@+1{{'fir.select_type' op selector must be polymorphic}} + fir.select_type %arg : !fir.box> [ + #fir.type_is>,^bb1(%0:i32), + #fir.type_is>,^bb2(%2:i32), + unit,^bb5 ] +^bb1(%a : i32) : + return %a : i32 +^bb2(%b : i32) : + return %b : i32 +^bb5 : + %zero = arith.constant 0 : i32 + return %zero : i32 +}