diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -488,6 +488,9 @@ mlir::Value getSelector(llvm::ArrayRef operands) { return operands[0]; } + mlir::Value getSelector(mlir::ValueRange operands) { + return operands.front(); + } // The number of blocks that may be branched to unsigned getNumDest() { return (*this)->getNumSuccessors(); } @@ -498,6 +501,8 @@ llvm::Optional> getSuccessorOperands( llvm::ArrayRef operands, unsigned cond); + llvm::Optional getSuccessorOperands( + mlir::ValueRange operands, unsigned cond); using BranchOpInterfaceTrait::getSuccessorOperands; // Helper function to deal with Optional operand forms diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -27,6 +27,8 @@ // fir::LLVMTypeConverter for converting to LLVM IR dialect types. 
#include "TypeConverter.h" +using OperandTy = ArrayRef; + namespace { /// FIR conversion pattern template template @@ -46,6 +48,13 @@ }; } // namespace +static Block *createBlock(mlir::ConversionPatternRewriter &rewriter, + mlir::Block *insertBefore) { + assert(insertBefore && "expected valid insertion block"); + return rewriter.createBlock(insertBefore->getParent(), + mlir::Region::iterator(insertBefore)); +} + namespace { struct AddrOfOpConversion : public FIROpConversion { using FIROpConversion::FIROpConversion; @@ -149,6 +158,94 @@ } }; +void genCondBrOp(mlir::Location loc, mlir::Value cmp, mlir::Block *dest, + Optional destOps, + mlir::ConversionPatternRewriter &rewriter, + mlir::Block *newBlock) { + if (destOps.hasValue()) + rewriter.create(loc, cmp, dest, destOps.getValue(), + newBlock, mlir::ValueRange()); + else + rewriter.create(loc, cmp, dest, newBlock); +} + +template +void genBrOp(A caseOp, mlir::Block *dest, Optional destOps, + mlir::ConversionPatternRewriter &rewriter) { + if (destOps.hasValue()) + rewriter.replaceOpWithNewOp(caseOp, destOps.getValue(), + dest); + else + rewriter.replaceOpWithNewOp(caseOp, llvm::None, dest); +} + +void genCaseLadderStep(mlir::Location loc, mlir::Value cmp, mlir::Block *dest, + Optional destOps, + mlir::ConversionPatternRewriter &rewriter) { + auto *thisBlock = rewriter.getInsertionBlock(); + auto *newBlock = createBlock(rewriter, dest); + rewriter.setInsertionPointToEnd(thisBlock); + genCondBrOp(loc, cmp, dest, destOps, rewriter, newBlock); + rewriter.setInsertionPointToEnd(newBlock); +} + +template +void selectMatchAndRewrite(fir::LLVMTypeConverter &lowering, OP select, + typename OP::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) { + // We could target the LLVM switch instruction, but it isn't part of the + // LLVM IR dialect. Create an if-then-else ladder instead. 
+ auto conds = select.getNumConditions(); + auto attrName = OP::getCasesAttr(); + auto caseAttr = select->template getAttrOfType(attrName); + auto cases = caseAttr.getValue(); + auto ty = select.getSelector().getType(); + auto ity = lowering.convertType(ty); + auto selector = select.getSelector(adaptor.getOperands()); + auto loc = select.getLoc(); + assert(conds > 0 && "select must have cases"); + for (decltype(conds) t = 0; t != conds; ++t) { + mlir::Block *dest = select.getSuccessor(t); + auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t); + auto &attr = cases[t]; + if (auto intAttr = attr.template dyn_cast()) { + auto ci = rewriter.create( + loc, ity, rewriter.getIntegerAttr(ty, intAttr.getInt())); + auto cmp = rewriter.create( + loc, mlir::LLVM::ICmpPredicate::eq, selector, ci); + genCaseLadderStep(loc, cmp, dest, destOps, rewriter); + continue; + } + assert(attr.template dyn_cast_or_null()); + assert((t + 1 == conds) && "unit must be last"); + genBrOp(select, dest, destOps, rewriter); + } +} + +/// conversion of fir::SelectOp to an if-then-else ladder +struct SelectOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::SelectOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + selectMatchAndRewrite(lowerTy(), op, adaptor, rewriter); + return success(); + } +}; + +/// conversion of fir::SelectRankOp to an if-then-else ladder +struct SelectRankOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::SelectRankOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + selectMatchAndRewrite(lowerTy(), op, adaptor, rewriter); + return success(); + } +}; + // convert to LLVM IR dialect `undef` struct UndefOpConversion : public FIROpConversion { using FIROpConversion::FIROpConversion; @@ -179,8 +276,10 @@ fir::LLVMTypeConverter 
typeConverter{getModule()}; auto loc = mlir::UnknownLoc::get(context); mlir::OwningRewritePatternList pattern(context); - pattern.insert(typeConverter); + pattern + .insert( + typeConverter); mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern); mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, pattern); diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -2323,6 +2323,15 @@ return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; } +llvm::Optional +fir::SelectOp::getSuccessorOperands(mlir::ValueRange operands, unsigned oper) { + auto a = + (*this)->getAttrOfType(getTargetOffsetAttr()); + auto segments = (*this)->getAttrOfType( + getOperandSegmentSizeAttr()); + return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; +} + unsigned fir::SelectOp::targetOffsetSize() { return denseElementsSize((*this)->getAttrOfType( getTargetOffsetAttr())); @@ -2616,6 +2625,16 @@ return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; } +llvm::Optional +fir::SelectRankOp::getSuccessorOperands(mlir::ValueRange operands, + unsigned oper) { + auto a = + (*this)->getAttrOfType(getTargetOffsetAttr()); + auto segments = (*this)->getAttrOfType( + getOperandSegmentSizeAttr()); + return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; +} + unsigned fir::SelectRankOp::targetOffsetSize() { return denseElementsSize((*this)->getAttrOfType( getTargetOffsetAttr())); diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir --- a/flang/test/Fir/convert-to-llvm.fir +++ b/flang/test/Fir/convert-to-llvm.fir @@ -81,3 +81,158 @@ // CHECK: %[[CST:.*]] = llvm.mlir.constant(dense<1> : vector<32x32xi32>) : !llvm.array<32 x array<32 x i32>> // CHECK: llvm.return %[[CST]] : !llvm.array<32 x array<32 x i32>> // CHECK: } + +// ----- + +// Test `fir.select` 
operation conversion pattern. +// Check that the if-then-else ladder is correctly constructed and that we +// branch to the correct block. + +func @select(%arg : index, %arg2 : i32) -> i32 { + %0 = arith.constant 1 : i32 + %1 = arith.constant 2 : i32 + %2 = arith.constant 3 : i32 + %3 = arith.constant 4 : i32 + fir.select %arg:index [ 1, ^bb1(%0:i32), + 2, ^bb2(%2,%arg,%arg2:i32,index,i32), + 3, ^bb3(%arg2,%2:i32,i32), + 4, ^bb4(%1:i32), + unit, ^bb5 ] + ^bb1(%a : i32) : + return %a : i32 + ^bb2(%b : i32, %b2 : index, %b3:i32) : + %castidx = arith.index_cast %b2 : index to i32 + %4 = arith.addi %b, %castidx : i32 + %5 = arith.addi %4, %b3 : i32 + return %5 : i32 + ^bb3(%c:i32, %c2:i32) : + %6 = arith.addi %c, %c2 : i32 + return %6 : i32 + ^bb4(%d : i32) : + return %d : i32 + ^bb5 : + %zero = arith.constant 0 : i32 + return %zero : i32 +} + +// CHECK-LABEL: func @select( +// CHECK-SAME: %[[SELECTVALUE:.*]]: [[IDX:.*]], +// CHECK-SAME: %[[ARG1:.*]]: i32) +// CHECK: %[[CST1:.*]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK: %[[CST2:.*]] = llvm.mlir.constant(2 : i32) : i32 +// CHECK: %[[CST3:.*]] = llvm.mlir.constant(3 : i32) : i32 +// CHECK: %[[CST4:.*]] = llvm.mlir.constant(4 : i32) : i32 +// CHECK: %[[MATCH1:.*]] = llvm.mlir.constant(1 : index) : [[IDX]] +// CHECK: %[[CMP1:.*]] = llvm.icmp "eq" %[[SELECTVALUE]], %[[MATCH1]] : [[IDX]] +// CHECK: llvm.cond_br %[[CMP1]], ^bb2(%[[CST1]] : i32), ^bb1 +// CHECK-LABEL: ^bb1: +// CHECK: %[[MATCH2:.*]] = llvm.mlir.constant(2 : index) : [[IDX]] +// CHECK: %[[CMP2:.*]] = llvm.icmp "eq" %[[SELECTVALUE]], %[[MATCH2]] : [[IDX]] +// CHECK: llvm.cond_br %[[CMP2]], ^bb4(%[[CST3]], %[[SELECTVALUE]], %[[ARG1]] : i32, [[IDX]], i32), ^bb3 +// Initially work done by select 1 (^bb1) +// CHECK-LABEL: ^bb2(%{{.*}}: i32) +// CHECK: llvm.return %{{.*}} : i32 +// CHECK-LABEL: ^bb3: +// CHECK: %[[MATCH3:.*]] = llvm.mlir.constant(3 : index) : [[IDX]] +// CHECK: %[[CMP3:.*]] = llvm.icmp "eq" %[[SELECTVALUE]], %[[MATCH3]] : [[IDX]] +// CHECK: 
llvm.cond_br %[[CMP3]], ^bb6(%[[ARG1]], %[[CST3]] : i32, i32), ^bb5 +// Initially work done by select 2 (^bb2) +// CHECK: ^bb4(%{{.*}}: i32, %{{.*}}: [[IDX]], %{{.*}}: i32): // pred: ^bb1 +// CHECK: %{{.*}} = llvm.trunc %{{.*}} : [[IDX]] to i32 +// CHECK: %{{.*}} = llvm.add %{{.*}}, %{{.*}} : i32 +// CHECK: %{{.*}} = llvm.add %{{.*}}, %{{.*}} : i32 +// CHECK: llvm.return %{{.*}} : i32 +// CHECK-LABEL: ^bb5: +// CHECK: %[[MATCH4:.*]] = llvm.mlir.constant(4 : index) : [[IDX]] +// CHECK: %[[CMP4:.*]] = llvm.icmp "eq" %[[SELECTVALUE]], %[[MATCH4]] : [[IDX]] +// CHECK: llvm.cond_br %[[CMP4]], ^bb8(%[[CST2]] : i32), ^bb7 +// Initially work done by select 3 (^bb3) +// CHECK-LABEL: ^bb6(%{{.*}}: i32, %{{.*}}: i32): +// CHECK: %[[ADD:.*]] = llvm.add %{{.*}}, %{{.*}} : i32 +// CHECK: llvm.return %{{.*}} : i32 +// CHECK-LABEL: ^bb7: +// CHECK: llvm.br ^bb9 +// Initially work done by select 4 (^bb4) +// CHECK-LABEL: ^bb8(%{{.*}}: i32): +// CHECK: llvm.return %{{.*}} : i32 +// Initially ^bb5 +// CHECK-LABEL: ^bb9: +// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK: llvm.return %[[ZERO]] : i32 + +// ----- + +// Test `fir.select_rank` operation conversion pattern. +// Check that the if-then-else ladder is correctly constructed and that we +// branch to the correct block. 
+ +func @select_rank(%arg : i32, %arg2 : i32) -> i32 { + %0 = arith.constant 1 : i32 + %1 = arith.constant 2 : i32 + %2 = arith.constant 3 : i32 + %3 = arith.constant 4 : i32 + fir.select_rank %arg:i32 [ 1, ^bb1(%0:i32), + 2, ^bb2(%2,%arg,%arg2:i32,i32,i32), + 3, ^bb3(%arg2,%2:i32,i32), + 4, ^bb4(%1:i32), + unit, ^bb5 ] + ^bb1(%a : i32) : + return %a : i32 + ^bb2(%b : i32, %b2 : i32, %b3:i32) : + %4 = arith.addi %b, %b2 : i32 + %5 = arith.addi %4, %b3 : i32 + return %5 : i32 + ^bb3(%c:i32, %c2:i32) : + %6 = arith.addi %c, %c2 : i32 + return %6 : i32 + ^bb4(%d : i32) : + return %d : i32 + ^bb5 : + %zero = arith.constant 0 : i32 + return %zero : i32 +} + +// CHECK-LABEL: func @select_rank( +// CHECK-SAME: %[[SELECTVALUE:.*]]: i32, +// CHECK-SAME: %[[ARG1:.*]]: i32) +// CHECK: %[[CST1:.*]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK: %[[CST2:.*]] = llvm.mlir.constant(2 : i32) : i32 +// CHECK: %[[CST3:.*]] = llvm.mlir.constant(3 : i32) : i32 +// CHECK: %[[CST4:.*]] = llvm.mlir.constant(4 : i32) : i32 +// CHECK: %[[MATCH1:.*]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK: %[[CMP1:.*]] = llvm.icmp "eq" %[[SELECTVALUE]], %[[MATCH1]] : i32 +// CHECK: llvm.cond_br %[[CMP1]], ^bb2(%[[CST1]] : i32), ^bb1 +// CHECK-LABEL: ^bb1: +// CHECK: %[[MATCH2:.*]] = llvm.mlir.constant(2 : i32) : i32 +// CHECK: %[[CMP2:.*]] = llvm.icmp "eq" %[[SELECTVALUE]], %[[MATCH2]] : i32 +// CHECK: llvm.cond_br %[[CMP2]], ^bb4(%[[CST3]], %[[SELECTVALUE]], %[[ARG1]] : i32, i32, i32), ^bb3 +// Initially work done by select 1 (^bb1) +// CHECK-LABEL: ^bb2(%{{.*}}: i32) +// CHECK: llvm.return %{{.*}} : i32 +// CHECK-LABEL: ^bb3: +// CHECK: %[[MATCH3:.*]] = llvm.mlir.constant(3 : i32) : i32 +// CHECK: %[[CMP3:.*]] = llvm.icmp "eq" %[[SELECTVALUE]], %[[MATCH3]] : i32 +// CHECK: llvm.cond_br %[[CMP3]], ^bb6(%[[ARG1]], %[[CST3]] : i32, i32), ^bb5 +// Initially work done by select 2 (^bb2) +// CHECK-LABEL: ^bb4(%{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i32): // pred: ^bb1 +// CHECK: %{{.*}} = llvm.add %{{.*}}, 
%{{.*}} : i32 +// CHECK: %{{.*}} = llvm.add %{{.*}}, %{{.*}} : i32 +// CHECK: llvm.return %{{.*}} : i32 +// CHECK-LABEL: ^bb5: +// CHECK: %[[MATCH4:.*]] = llvm.mlir.constant(4 : i32) : i32 +// CHECK: %[[CMP4:.*]] = llvm.icmp "eq" %[[SELECTVALUE]], %[[MATCH4]] : i32 +// CHECK: llvm.cond_br %[[CMP4]], ^bb8(%[[CST2]] : i32), ^bb7 +// Initially work done by select 3 (^bb3) +// CHECK-LABEL: ^bb6(%{{.*}}: i32, %{{.*}}: i32): +// CHECK: %[[ADD:.*]] = llvm.add %{{.*}}, %{{.*}} : i32 +// CHECK: llvm.return %{{.*}} : i32 +// CHECK-LABEL: ^bb7: +// CHECK: llvm.br ^bb9 +// Initially work done by select 4 (^bb4) +// CHECK-LABEL: ^bb8(%{{.*}}: i32): +// CHECK: llvm.return %{{.*}} : i32 +// Initially ^bb5 +// CHECK-LABEL: ^bb9: +// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK: llvm.return %[[ZERO]] : i32 +