diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -488,6 +488,9 @@ mlir::Value getSelector(llvm::ArrayRef operands) { return operands[0]; } + mlir::Value getSelector(mlir::ValueRange operands) { + return operands.front(); + } // The number of blocks that may be branched to unsigned getNumDest() { return (*this)->getNumSuccessors(); } @@ -498,6 +501,8 @@ llvm::Optional> getSuccessorOperands( llvm::ArrayRef operands, unsigned cond); + llvm::Optional getSuccessorOperands( + mlir::ValueRange operands, unsigned cond); using BranchOpInterfaceTrait::getSuccessorOperands; // Helper function to deal with Optional operand forms diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -149,6 +149,84 @@ } }; +/// Rewrite a FIR select-like op (fir::SelectOp, fir::SelectRankOp) as an +/// LLVM::SwitchOp: integer case attributes become switch cases and the +/// trailing unit attribute provides the default destination. +template +void selectMatchAndRewrite(fir::LLVMTypeConverter &lowering, OP select, + typename OP::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) { + auto conds = select.getNumConditions(); + auto caseAttr = + select->template getAttrOfType(OP::getCasesAttr()); + auto cases = caseAttr.getValue(); + auto selector = select.getSelector(adaptor.getOperands()); + auto loc = select.getLoc(); + assert(conds > 0 && "select must have cases"); + + llvm::SmallVector destinations; + llvm::SmallVector destinationsOperands; + mlir::Block *defaultDestination = nullptr; + mlir::ValueRange defaultOperands; + llvm::SmallVector caseValues; + + for (decltype(conds) t = 0; t != conds; ++t) { + mlir::Block *dest = select.getSuccessor(t); + auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t); + auto &attr = cases[t]; + if (auto intAttr = attr.template dyn_cast()) { + destinations.push_back(dest); + destinationsOperands.push_back(destOps.hasValue() ? 
*destOps + : ValueRange()); + caseValues.push_back(intAttr.getInt()); + continue; + } + assert(attr.template dyn_cast_or_null()); + assert((t + 1 == conds) && "unit must be last"); + defaultDestination = dest; + defaultOperands = destOps.hasValue() ? *destOps : ValueRange(); + } + assert(defaultDestination && "select must have a default (unit) case"); + + // LLVM::SwitchOp takes an i32 type for the selector. + if (select.getSelector().getType() != rewriter.getI32Type()) + selector = + rewriter.create(loc, rewriter.getI32Type(), selector); + + rewriter.replaceOpWithNewOp( + select, selector, + /*defaultDestination=*/defaultDestination, + /*defaultOperands=*/defaultOperands, + /*caseValues=*/caseValues, + /*caseDestinations=*/destinations, + /*caseOperands=*/destinationsOperands, + /*branchWeights=*/ArrayRef()); +} + +/// conversion of fir::SelectOp to an LLVM::SwitchOp +struct SelectOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::SelectOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + selectMatchAndRewrite(lowerTy(), op, adaptor, rewriter); + return success(); + } +}; + +/// conversion of fir::SelectRankOp to an LLVM::SwitchOp +struct SelectRankOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::SelectRankOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + selectMatchAndRewrite(lowerTy(), op, adaptor, rewriter); + return success(); + } +}; + // convert to LLVM IR dialect `undef` struct UndefOpConversion : public FIROpConversion { using FIROpConversion::FIROpConversion; @@ -179,8 +257,10 @@ fir::LLVMTypeConverter typeConverter{getModule()}; auto loc = mlir::UnknownLoc::get(context); mlir::OwningRewritePatternList pattern(context); - pattern.insert(typeConverter); + pattern + .insert( + typeConverter); mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern); 
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, pattern); diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -2323,6 +2323,15 @@ return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; } +llvm::Optional +fir::SelectOp::getSuccessorOperands(mlir::ValueRange operands, unsigned oper) { + auto a = + (*this)->getAttrOfType(getTargetOffsetAttr()); + auto segments = (*this)->getAttrOfType( + getOperandSegmentSizeAttr()); + return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; +} + unsigned fir::SelectOp::targetOffsetSize() { return denseElementsSize((*this)->getAttrOfType( getTargetOffsetAttr())); @@ -2616,6 +2625,16 @@ return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; } +llvm::Optional +fir::SelectRankOp::getSuccessorOperands(mlir::ValueRange operands, + unsigned oper) { + auto a = + (*this)->getAttrOfType(getTargetOffsetAttr()); + auto segments = (*this)->getAttrOfType( + getOperandSegmentSizeAttr()); + return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; +} + unsigned fir::SelectRankOp::targetOffsetSize() { return denseElementsSize((*this)->getAttrOfType( getTargetOffsetAttr())); diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir --- a/flang/test/Fir/convert-to-llvm.fir +++ b/flang/test/Fir/convert-to-llvm.fir @@ -81,3 +81,95 @@ // CHECK: %[[CST:.*]] = llvm.mlir.constant(dense<1> : vector<32x32xi32>) : !llvm.array<32 x array<32 x i32>> // CHECK: llvm.return %[[CST]] : !llvm.array<32 x array<32 x i32>> // CHECK: } + +// ----- + +// Test `fir.select` operation conversion pattern. +// Check that the `llvm.switch` is correctly constructed and that we +// branch to the correct block. 
+ +func @select(%arg : index, %arg2 : i32) -> i32 { + %0 = arith.constant 1 : i32 + %1 = arith.constant 2 : i32 + %2 = arith.constant 3 : i32 + %3 = arith.constant 4 : i32 + fir.select %arg:index [ 1, ^bb1(%0:i32), + 2, ^bb2(%2,%arg,%arg2:i32,index,i32), + 3, ^bb3(%arg2,%2:i32,i32), + 4, ^bb4(%1:i32), + unit, ^bb5 ] + ^bb1(%a : i32) : + return %a : i32 + ^bb2(%b : i32, %b2 : index, %b3:i32) : + %castidx = arith.index_cast %b2 : index to i32 + %4 = arith.addi %b, %castidx : i32 + %5 = arith.addi %4, %b3 : i32 + return %5 : i32 + ^bb3(%c:i32, %c2:i32) : + %6 = arith.addi %c, %c2 : i32 + return %6 : i32 + ^bb4(%d : i32) : + return %d : i32 + ^bb5 : + %zero = arith.constant 0 : i32 + return %zero : i32 +} + +// CHECK-LABEL: func @select( +// CHECK-SAME: %[[SELECTVALUE:.*]]: [[IDX:.*]], +// CHECK-SAME: %[[ARG1:.*]]: i32) +// CHECK: %[[C0:.*]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK: %[[C1:.*]] = llvm.mlir.constant(2 : i32) : i32 +// CHECK: %[[C2:.*]] = llvm.mlir.constant(3 : i32) : i32 +// CHECK: %[[SELECTOR:.*]] = llvm.trunc %[[SELECTVALUE]] : i{{.*}} to i32 +// CHECK: llvm.switch %[[SELECTOR]], ^bb5 [ +// CHECK: 1: ^bb1(%[[C0]] : i32), +// CHECK: 2: ^bb2(%[[C2]], %[[SELECTVALUE]], %[[ARG1]] : i32, [[IDX]], i32), +// CHECK: 3: ^bb3(%[[ARG1]], %[[C2]] : i32, i32), +// CHECK: 4: ^bb4(%[[C1]] : i32) +// CHECK: ] + +// ----- + +// Test `fir.select_rank` operation conversion pattern. +// Check that the `llvm.switch` is correctly constructed and that we +// branch to the correct block. 
+ +func @select_rank(%arg : i32, %arg2 : i32) -> i32 { + %0 = arith.constant 1 : i32 + %1 = arith.constant 2 : i32 + %2 = arith.constant 3 : i32 + %3 = arith.constant 4 : i32 + fir.select_rank %arg:i32 [ 1, ^bb1(%0:i32), + 2, ^bb2(%2,%arg,%arg2:i32,i32,i32), + 3, ^bb3(%arg2,%2:i32,i32), + 4, ^bb4(%1:i32), + unit, ^bb5 ] + ^bb1(%a : i32) : + return %a : i32 + ^bb2(%b : i32, %b2 : i32, %b3:i32) : + %4 = arith.addi %b, %b2 : i32 + %5 = arith.addi %4, %b3 : i32 + return %5 : i32 + ^bb3(%c:i32, %c2:i32) : + %6 = arith.addi %c, %c2 : i32 + return %6 : i32 + ^bb4(%d : i32) : + return %d : i32 + ^bb5 : + %zero = arith.constant 0 : i32 + return %zero : i32 +} + +// CHECK-LABEL: func @select_rank( +// CHECK-SAME: %[[SELECTVALUE:.*]]: i32, +// CHECK-SAME: %[[ARG1:.*]]: i32) +// CHECK: %[[C0:.*]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK: %[[C1:.*]] = llvm.mlir.constant(2 : i32) : i32 +// CHECK: %[[C2:.*]] = llvm.mlir.constant(3 : i32) : i32 +// CHECK: llvm.switch %[[SELECTVALUE]], ^bb5 [ +// CHECK: 1: ^bb1(%[[C0]] : i32), +// CHECK: 2: ^bb2(%[[C2]], %[[SELECTVALUE]], %[[ARG1]] : i32, i32, i32), +// CHECK: 3: ^bb3(%[[ARG1]], %[[C2]] : i32, i32), +// CHECK: 4: ^bb4(%[[C1]] : i32) +// CHECK: ]