diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -498,6 +498,8 @@ llvm::Optional getCompareOperands(unsigned cond); llvm::Optional> getCompareOperands( llvm::ArrayRef operands, unsigned cond); + llvm::Optional getCompareOperands( + mlir::ValueRange operands, unsigned cond); llvm::Optional> getSuccessorOperands( llvm::ArrayRef operands, unsigned cond); diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -13,6 +13,7 @@ #include "flang/Optimizer/CodeGen/CodeGen.h" #include "PassDetail.h" #include "flang/ISO_Fortran_binding.h" +#include "flang/Optimizer/Dialect/FIRAttr.h" #include "flang/Optimizer/Dialect/FIROps.h" #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" @@ -40,6 +41,13 @@ return rewriter.create(loc, ity, cattr); } +static Block *createBlock(mlir::ConversionPatternRewriter &rewriter, + mlir::Block *insertBefore) { + assert(insertBefore && "expected valid insertion block"); + return rewriter.createBlock(insertBefore->getParent(), + mlir::Region::iterator(insertBefore)); +} + namespace { /// FIR conversion pattern template template @@ -695,6 +703,122 @@ } }; +void genCondBrOp(mlir::Location loc, mlir::Value cmp, mlir::Block *dest, + Optional destOps, + mlir::ConversionPatternRewriter &rewriter, + mlir::Block *newBlock) { + if (destOps.hasValue()) + rewriter.create(loc, cmp, dest, destOps.getValue(), + newBlock, mlir::ValueRange()); + else + rewriter.create(loc, cmp, dest, newBlock); +} + +template +void genBrOp(A caseOp, mlir::Block *dest, Optional destOps, + mlir::ConversionPatternRewriter &rewriter) { + if (destOps.hasValue()) + rewriter.replaceOpWithNewOp(caseOp, destOps.getValue(), + dest); + 
else + rewriter.replaceOpWithNewOp(caseOp, llvm::None, dest); +} + +void genCaseLadderStep(mlir::Location loc, mlir::Value cmp, mlir::Block *dest, + Optional destOps, + mlir::ConversionPatternRewriter &rewriter) { + auto *thisBlock = rewriter.getInsertionBlock(); + auto *newBlock = createBlock(rewriter, dest); + rewriter.setInsertionPointToEnd(thisBlock); + genCondBrOp(loc, cmp, dest, destOps, rewriter, newBlock); + rewriter.setInsertionPointToEnd(newBlock); +} + +/// Conversion of `fir.select_case` +/// +/// The `fir.select_case` operation is converted to an if-then-else ladder. +/// Depending on the case condition type, one or several comparisons and +/// conditional branches can be generated. +/// +/// A point value case such as `case(4)`, a lower bound case such as +/// `case(5:)` or an upper bound case such as `case(:3)` is converted to a +/// simple comparison between the selector value and the constant value in the +/// case. The block associated with the case condition is then executed if +/// the comparison succeeds; otherwise it branches to the next block with the +/// comparison for the next case condition. +/// +/// A closed interval case condition such as `case(7:10)` is converted with a +/// first comparison and conditional branching for the lower bound. If +/// successful, it branches to a second block with the comparison for the +/// upper bound in the same case condition. +/// +/// TODO: lowering of CHARACTER type cases is not handled yet. 
+struct SelectCaseOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::SelectCaseOp caseOp, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + unsigned conds = caseOp.getNumConditions(); + llvm::ArrayRef cases = caseOp.getCases().getValue(); + // Type can be CHARACTER, INTEGER, or LOGICAL (C1145) + auto ty = caseOp.getSelector().getType(); + if (ty.isa()) + return rewriter.notifyMatchFailure(caseOp, + "conversion of fir.select_case with " + "character type not implemented yet"); + mlir::Value selector = caseOp.getSelector(adaptor.getOperands()); + auto loc = caseOp.getLoc(); + for (unsigned t = 0; t != conds; ++t) { + mlir::Block *dest = caseOp.getSuccessor(t); + llvm::Optional destOps = + caseOp.getSuccessorOperands(adaptor.getOperands(), t); + // Compare operands are read per case: the default (unit) case has none. + mlir::ValueRange cmpOps = + *caseOp.getCompareOperands(adaptor.getOperands(), t); + mlir::Attribute attr = cases[t]; + if (attr.isa()) { + auto cmp = rewriter.create( + loc, mlir::LLVM::ICmpPredicate::eq, selector, cmpOps.front()); + genCaseLadderStep(loc, cmp, dest, destOps, rewriter); + continue; + } + if (attr.isa()) { + auto cmp = rewriter.create( + loc, mlir::LLVM::ICmpPredicate::sle, cmpOps.front(), selector); + genCaseLadderStep(loc, cmp, dest, destOps, rewriter); + continue; + } + if (attr.isa()) { + auto cmp = rewriter.create( + loc, mlir::LLVM::ICmpPredicate::sle, selector, cmpOps.front()); + genCaseLadderStep(loc, cmp, dest, destOps, rewriter); + continue; + } + if (attr.isa()) { + auto cmp = rewriter.create( + loc, mlir::LLVM::ICmpPredicate::sle, cmpOps.front(), selector); + auto *thisBlock = rewriter.getInsertionBlock(); + auto *newBlock1 = createBlock(rewriter, dest); + auto *newBlock2 = createBlock(rewriter, dest); + rewriter.setInsertionPointToEnd(thisBlock); + rewriter.create(loc, cmp, newBlock1, newBlock2); + rewriter.setInsertionPointToEnd(newBlock1); + 
mlir::Value caseArg0 = cmpOps[1]; + auto cmp0 = rewriter.create( + loc, mlir::LLVM::ICmpPredicate::sle, selector, caseArg0); + genCondBrOp(loc, cmp0, dest, destOps, rewriter, newBlock2); + rewriter.setInsertionPointToEnd(newBlock2); + continue; + } + assert(attr.isa()); + assert((t + 1 == conds) && "unit must be last"); + genBrOp(caseOp, dest, destOps, rewriter); + } + return success(); + } +}; + template void selectMatchAndRewrite(fir::LLVMTypeConverter &lowering, OP select, typename OP::Adaptor adaptor, @@ -1233,9 +1357,9 @@ DivcOpConversion, ExtractValueOpConversion, HasValueOpConversion, GlobalOpConversion, InsertOnRangeOpConversion, InsertValueOpConversion, LoadOpConversion, NegcOpConversion, MulcOpConversion, - SelectOpConversion, SelectRankOpConversion, StoreOpConversion, - SubcOpConversion, UndefOpConversion, UnreachableOpConversion, - ZeroOpConversion>(typeConverter); + SelectCaseOpConversion, SelectOpConversion, SelectRankOpConversion, + StoreOpConversion, SubcOpConversion, UndefOpConversion, + UnreachableOpConversion, ZeroOpConversion>(typeConverter); mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern); mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, pattern); diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -2297,6 +2297,16 @@ return {getSubOperands(cond, getSubOperands(1, operands, segments), a)}; } +llvm::Optional +fir::SelectCaseOp::getCompareOperands(mlir::ValueRange operands, + unsigned cond) { + auto a = (*this)->getAttrOfType( + getCompareOffsetAttr()); + auto segments = (*this)->getAttrOfType( + getOperandSegmentSizeAttr()); + return {getSubOperands(cond, getSubOperands(1, operands, segments), a)}; +} + llvm::Optional fir::SelectCaseOp::getMutableSuccessorOperands(unsigned oper) { return ::getMutableSuccessorOperands(oper, targetArgsMutable(), @@ 
-2313,6 +2323,16 @@ return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; } +llvm::Optional +fir::SelectCaseOp::getSuccessorOperands(mlir::ValueRange operands, + unsigned oper) { + auto a = + (*this)->getAttrOfType(getTargetOffsetAttr()); + auto segments = (*this)->getAttrOfType( + getOperandSegmentSizeAttr()); + return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; +} + // parser for fir.select_case Op static mlir::ParseResult parseSelectCase(mlir::OpAsmParser &parser, mlir::OperationState &result) { diff --git a/flang/test/Fir/convert-to-llvm-invalid.fir b/flang/test/Fir/convert-to-llvm-invalid.fir --- a/flang/test/Fir/convert-to-llvm-invalid.fir +++ b/flang/test/Fir/convert-to-llvm-invalid.fir @@ -2,6 +2,9 @@ // RUN: fir-opt --split-input-file --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" --verify-diagnostics %s +// Test `fir.zero` conversion failure with aggregate type. +// Not implemented yet. + func @zero_aggregate() { // expected-error@+1{{failed to legalize operation 'fir.zero_bits'}} %a = fir.zero_bits !fir.array<10xf32> @@ -27,3 +30,23 @@ fir.dispatch_table @dispatch_tbl { fir.dt_entry "method", @method_impl } + +// ----- + +// Test `fir.select_case` conversion failure with character type. +// Not implemented yet. 
+ +func @select_case_charachter(%arg0: !fir.char<2, 10>, %arg1: !fir.char<2, 10>, %arg2: !fir.char<2, 10>) { + // expected-error@+1{{failed to legalize operation 'fir.select_case'}} + fir.select_case %arg0 : !fir.char<2, 10> [#fir.point, %arg1, ^bb1, + #fir.point, %arg2, ^bb2, + unit, ^bb3] +^bb1: + %c1_i32 = arith.constant 1 : i32 + br ^bb3 +^bb2: + %c2_i32 = arith.constant 2 : i32 + br ^bb3 +^bb3: + return +} diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir --- a/flang/test/Fir/convert-to-llvm.fir +++ b/flang/test/Fir/convert-to-llvm.fir @@ -961,3 +961,155 @@ // CHECK: [[PROD3:%.*]] = llvm.mul [[PROD2]], [[B]] : i64 // CHECK: [[RES:%.*]] = llvm.alloca [[PROD3]] x i32 {in_type = !fir.array<4x?x3x?x5xi32> // CHECK: llvm.return [[RES]] : !llvm.ptr + +// ----- + +// Test `fir.select_case` operation conversion with INTEGER. + +func @select_case_integer(%arg0: !fir.ref) -> i32 { + %2 = fir.load %arg0 : !fir.ref + %c1_i32 = arith.constant 1 : i32 + %c2_i32 = arith.constant 2 : i32 + %c4_i32 = arith.constant 4 : i32 + %c5_i32 = arith.constant 5 : i32 + %c7_i32 = arith.constant 7 : i32 + %c8_i32 = arith.constant 8 : i32 + %c15_i32 = arith.constant 15 : i32 + %c21_i32 = arith.constant 21 : i32 + fir.select_case %2 : i32 [#fir.upper, %c1_i32, ^bb1, + #fir.point, %c2_i32, ^bb2, + #fir.interval, %c4_i32, %c5_i32, ^bb4, + #fir.point, %c7_i32, ^bb5, + #fir.interval, %c8_i32, %c15_i32, ^bb5, + #fir.lower, %c21_i32, ^bb5, + unit, ^bb3] +^bb1: // pred: ^bb0 + %c1_i32_0 = arith.constant 1 : i32 + fir.store %c1_i32_0 to %arg0 : !fir.ref + br ^bb6 +^bb2: // pred: ^bb0 + %c2_i32_1 = arith.constant 2 : i32 + fir.store %c2_i32_1 to %arg0 : !fir.ref + br ^bb6 +^bb3: // pred: ^bb0 + %c0_i32 = arith.constant 0 : i32 + fir.store %c0_i32 to %arg0 : !fir.ref + br ^bb6 +^bb4: // pred: ^bb0 + %c4_i32_2 = arith.constant 4 : i32 + fir.store %c4_i32_2 to %arg0 : !fir.ref + br ^bb6 +^bb5: // 3 preds: ^bb0, ^bb0, ^bb0 + %c7_i32_3 = arith.constant 7 : i32 + 
fir.store %c7_i32_3 to %arg0 : !fir.ref + br ^bb6 +^bb6: // 5 preds: ^bb1, ^bb2, ^bb3, ^bb4, ^bb5 + %3 = fir.load %arg0 : !fir.ref + return %3 : i32 +} + +// CHECK-LABEL: llvm.func @select_case_integer( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr) -> i32 { +// CHECK: %[[SELECT_VALUE:.*]] = llvm.load %[[ARG0]] : !llvm.ptr +// CHECK: %[[CST1:.*]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK: %[[CST2:.*]] = llvm.mlir.constant(2 : i32) : i32 +// CHECK: %[[CST4:.*]] = llvm.mlir.constant(4 : i32) : i32 +// CHECK: %[[CST5:.*]] = llvm.mlir.constant(5 : i32) : i32 +// CHECK: %[[CST7:.*]] = llvm.mlir.constant(7 : i32) : i32 +// CHECK: %[[CST8:.*]] = llvm.mlir.constant(8 : i32) : i32 +// CHECK: %[[CST15:.*]] = llvm.mlir.constant(15 : i32) : i32 +// CHECK: %[[CST21:.*]] = llvm.mlir.constant(21 : i32) : i32 +// Check for upper bound `case (:1)` +// CHECK: %[[CMP_SLE:.*]] = llvm.icmp "sle" %[[SELECT_VALUE]], %[[CST1]] : i32 +// CHECK: llvm.cond_br %[[CMP_SLE]], ^bb2, ^bb1 +// CHECK-LABEL: ^bb1: +// Check for point value `case (2)` +// CHECK: %[[CMP_EQ:.*]] = llvm.icmp "eq" %[[SELECT_VALUE]], %[[CST2]] : i32 +// CHECK: llvm.cond_br %[[CMP_EQ]], ^bb4, ^bb3 +// Block ^bb1 in original FIR code. +// CHECK-LABEL: ^bb2: +// CHECK: llvm.br ^bb{{.*}} +// CHECK-LABEL: ^bb3: +// Check for the lower bound for the interval `case (4:5)` +// CHECK: %[[CMP_SLE:.*]] = llvm.icmp "sle" %[[CST4]], %[[SELECT_VALUE]] : i32 +// CHECK: llvm.cond_br %[[CMP_SLE]], ^bb[[UPPERBOUND5:.*]], ^bb7 +// Block ^bb2 in original FIR code. +// CHECK-LABEL: ^bb4: +// CHECK: llvm.br ^bb{{.*}} +// Block ^bb3 in original FIR code. 
+// CHECK-LABEL: ^bb5: +// CHECK: llvm.br ^bb{{.*}} +// CHECK: ^bb[[UPPERBOUND5]]: +// Check for the upper bound for the interval `case (4:5)` +// CHECK: %[[CMP_SLE:.*]] = llvm.icmp "sle" %[[SELECT_VALUE]], %[[CST5]] : i32 +// CHECK: llvm.cond_br %[[CMP_SLE]], ^bb8, ^bb7 +// CHECK-LABEL: ^bb7: +// Check for the point value 7 in `case (7,8:15,21:)` +// CHECK: %[[CMP_EQ:.*]] = llvm.icmp "eq" %[[SELECT_VALUE]], %[[CST7]] : i32 +// CHECK: llvm.cond_br %[[CMP_EQ]], ^bb13, ^bb9 +// Block ^bb4 in original FIR code. +// CHECK-LABEL: ^bb8: +// CHECK: llvm.br ^bb{{.*}} +// CHECK-LABEL: ^bb9: +// Check for lower bound 8 in `case (7,8:15,21:)` +// CHECK: %[[CMP_SLE:.*]] = llvm.icmp "sle" %[[CST8]], %[[SELECT_VALUE]] : i32 +// CHECK: llvm.cond_br %[[CMP_SLE]], ^bb[[INTERVAL8_15:.*]], ^bb11 +// CHECK: ^bb[[INTERVAL8_15]]: +// Check for upper bound 15 in `case (7,8:15,21:)` +// CHECK: %[[CMP_SLE:.*]] = llvm.icmp "sle" %[[SELECT_VALUE]], %[[CST15]] : i32 +// CHECK: llvm.cond_br %[[CMP_SLE]], ^bb13, ^bb11 +// CHECK-LABEL: ^bb11: +// Check for lower bound 21 in `case (7,8:15,21:)` +// CHECK: %[[CMP_SLE:.*]] = llvm.icmp "sle" %[[CST21]], %[[SELECT_VALUE]] : i32 +// CHECK: llvm.cond_br %[[CMP_SLE]], ^bb13, ^bb12 +// CHECK-LABEL: ^bb12: +// CHECK: llvm.br ^bb5 +// Block ^bb5 in original FIR code. +// CHECK-LABEL: ^bb13: +// CHECK: llvm.br ^bb14 +// Block ^bb6 in original FIR code. +// CHECK-LABEL: ^bb14: +// CHECK: %[[RET:.*]] = llvm.load %[[ARG0:.*]] : !llvm.ptr +// CHECK: llvm.return %[[RET]] : i32 + +// ----- + +// Test `fir.select_case` operation conversion with LOGICAL. 
+ +func @select_case_logical(%arg0: !fir.ref>) { + %1 = fir.load %arg0 : !fir.ref> + %2 = fir.convert %1 : (!fir.logical<4>) -> i1 + %false = arith.constant false + %true = arith.constant true + fir.select_case %2 : i1 [#fir.point, %false, ^bb1, + #fir.point, %true, ^bb2, + unit, ^bb3] +^bb1: + %c1_i32 = arith.constant 1 : i32 + br ^bb3 +^bb2: + %c2_i32 = arith.constant 2 : i32 + br ^bb3 +^bb3: + return +} + +// CHECK-LABEL: llvm.func @select_case_logical( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr +// CHECK: %[[LOAD_ARG0:.*]] = llvm.load %[[ARG0]] : !llvm.ptr +// CHECK: %[[SELECT_VALUE:.*]] = llvm.trunc %[[LOAD_ARG0]] : i32 to i1 +// CHECK: %[[CST_FALSE:.*]] = llvm.mlir.constant(false) : i1 +// CHECK: %[[CST_TRUE:.*]] = llvm.mlir.constant(true) : i1 +// CHECK: %[[CMPEQ:.*]] = llvm.icmp "eq" %[[SELECT_VALUE]], %[[CST_FALSE]] : i1 +// CHECK: llvm.cond_br %[[CMPEQ]], ^bb2, ^bb1 +// CHECK-LABEL: ^bb1: +// CHECK: %[[CMPEQ:.*]] = llvm.icmp "eq" %[[SELECT_VALUE]], %[[CST_TRUE]] : i1 +// CHECK: llvm.cond_br %[[CMPEQ]], ^bb4, ^bb3 +// CHECK-LABEL: ^bb2: +// CHECK: llvm.br ^bb5 +// CHECK-LABEL: ^bb3: +// CHECK: llvm.br ^bb5 +// CHECK-LABEL: ^bb4: +// CHECK: llvm.br ^bb5 +// CHECK-LABEL: ^bb5: +// CHECK: llvm.return