Index: flang/lib/Optimizer/CMakeLists.txt
===================================================================
--- flang/lib/Optimizer/CMakeLists.txt
+++ flang/lib/Optimizer/CMakeLists.txt
@@ -17,10 +17,12 @@
   CodeGen/TargetRewrite.cpp
 
   Transforms/Inliner.cpp
+  Transforms/RewriteLoop.cpp
 
   DEPENDS
   FIROpsIncGen
   FIROptCodeGenPassIncGen
+  FIROptTransformsPassIncGen
   CGOpsIncGen
 
   ${dialect_libs}
Index: flang/lib/Optimizer/Transforms/PassDetail.h
===================================================================
--- /dev/null
+++ flang/lib/Optimizer/Transforms/PassDetail.h
@@ -0,0 +1,22 @@
+//===- PassDetail.h - Optimizer Transforms Pass class details ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_OPTIMIZER_TRANSFORMS_PASSDETAIL_H
+#define FORTRAN_OPTIMIZER_TRANSFORMS_PASSDETAIL_H
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+
+namespace fir {
+
+#define GEN_PASS_CLASSES
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+
+} // namespace fir
+
+#endif // FORTRAN_OPTIMIZER_TRANSFORMS_PASSDETAIL_H
Index: flang/lib/Optimizer/Transforms/RewriteLoop.cpp
===================================================================
--- /dev/null
+++ flang/lib/Optimizer/Transforms/RewriteLoop.cpp
@@ -0,0 +1,324 @@
+//===-- RewriteLoop.cpp ---------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/CommandLine.h"
+
+/// when set, skip the FIR to CFG conversion pass entirely
+static llvm::cl::opt<bool>
+    disableCfgConversion("disable-cfg-conversion",
+                         llvm::cl::desc("disable FIR to CFG pass"),
+                         llvm::cl::init(false));
+
+/// when set, the minimum trip count of a `fir.do_loop` is 1, not 0
+static llvm::cl::opt<bool> forceLoopToExecuteOnce(
+    "always-execute-loop-body",
+    llvm::cl::desc("force the body of a loop to execute at least once"),
+    llvm::cl::init(false));
+
+using namespace fir;
+
+namespace {
+
+// Conversion of fir control ops to more primitive control-flow.
+//
+// FIR loops that cannot be converted to the affine dialect will remain as
+// `fir.do_loop` operations. These can be converted to control-flow operations.
+
+/// Convert `fir.do_loop` to CFG using a trip-count (iterations left) counter.
+class CfgLoopConv : public mlir::OpRewritePattern<fir::DoLoopOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(DoLoopOp loop,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto loc = loop.getLoc();
+
+    // Create the start and end blocks that will wrap the DoLoopOp with an
+    // initializer and an end point
+    auto *initBlock = rewriter.getInsertionBlock();
+    auto initPos = rewriter.getInsertionPoint();
+    auto *endBlock = rewriter.splitBlock(initBlock, initPos);
+
+    // Split the first DoLoopOp block in two parts. The part before will be the
+    // conditional block since it already has the induction variable and
+    // loop-carried values as arguments.
+    auto *conditionalBlock = &loop.region().front();
+    conditionalBlock->addArgument(rewriter.getIndexType());
+    auto *firstBlock =
+        rewriter.splitBlock(conditionalBlock, conditionalBlock->begin());
+    auto *lastBlock = &loop.region().back();
+
+    // Move the blocks from the DoLoopOp between initBlock and endBlock
+    rewriter.inlineRegionBefore(loop.region(), endBlock);
+
+    // Get loop values from the DoLoopOp
+    auto low = loop.lowerBound();
+    auto high = loop.upperBound();
+    assert(low && high && "must be a Value");
+    auto step = loop.step();
+
+    // Initialization block: trip count = (high - low + step) / step
+    rewriter.setInsertionPointToEnd(initBlock);
+    auto diff = rewriter.create<mlir::SubIOp>(loc, high, low);
+    auto distance = rewriter.create<mlir::AddIOp>(loc, diff, step);
+    mlir::Value iters =
+        rewriter.create<mlir::SignedDivIOp>(loc, distance, step);
+
+    if (forceLoopToExecuteOnce) {
+      auto zero = rewriter.create<mlir::ConstantIndexOp>(loc, 0);
+      auto cond =
+          rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sle, iters, zero);
+      auto one = rewriter.create<mlir::ConstantIndexOp>(loc, 1);
+      iters = rewriter.create<mlir::SelectOp>(loc, cond, one, iters);
+    }
+
+    llvm::SmallVector<mlir::Value> loopOperands;
+    loopOperands.push_back(low);
+    auto operands = loop.getIterOperands();
+    loopOperands.append(operands.begin(), operands.end());
+    loopOperands.push_back(iters);
+
+    rewriter.create<mlir::BranchOp>(loc, conditionalBlock, loopOperands);
+
+    // Last loop block: step the induction variable, decrement the trip count
+    auto *terminator = lastBlock->getTerminator();
+    rewriter.setInsertionPointToEnd(lastBlock);
+    auto iv = conditionalBlock->getArgument(0);
+    mlir::Value steppedIndex = rewriter.create<mlir::AddIOp>(loc, iv, step);
+    assert(steppedIndex && "must be a Value");
+    auto lastArg = conditionalBlock->getNumArguments() - 1;
+    auto itersLeft = conditionalBlock->getArgument(lastArg);
+    auto one = rewriter.create<mlir::ConstantIndexOp>(loc, 1);
+    mlir::Value itersMinusOne =
+        rewriter.create<mlir::SubIOp>(loc, itersLeft, one);
+
+    llvm::SmallVector<mlir::Value> loopCarried;
+    loopCarried.push_back(steppedIndex);
+    auto begin = loop.finalValue() ? std::next(terminator->operand_begin())
+                                   : terminator->operand_begin();
+    loopCarried.append(begin, terminator->operand_end());
+    loopCarried.push_back(itersMinusOne);
+    rewriter.create<mlir::BranchOp>(loc, conditionalBlock, loopCarried);
+    rewriter.eraseOp(terminator);
+
+    // Conditional block
+    rewriter.setInsertionPointToEnd(conditionalBlock);
+    auto zero = rewriter.create<mlir::ConstantIndexOp>(loc, 0);
+    auto comparison =
+        rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt, itersLeft, zero);
+
+    rewriter.create<mlir::CondBranchOp>(loc, comparison, firstBlock,
+                                        llvm::ArrayRef<mlir::Value>(), endBlock,
+                                        llvm::ArrayRef<mlir::Value>());
+
+    // The loop results are the condition block arguments minus the trip count
+    // (last argument) and, unless finalValue() is set, the induction variable.
+    auto args = loop.finalValue()
+                    ? conditionalBlock->getArguments()
+                    : conditionalBlock->getArguments().drop_front();
+    rewriter.replaceOp(loop, args.drop_back());
+    return success();
+  }
+};
+
+/// Convert `fir.if` to control-flow
+class CfgIfConv : public mlir::OpRewritePattern<fir::IfOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(IfOp ifOp, mlir::PatternRewriter &rewriter) const override {
+    auto loc = ifOp.getLoc();
+
+    // Split the block containing the 'fir.if' into two parts. The part before
+    // will contain the condition, the part after will be the continuation
+    // point.
+    auto *condBlock = rewriter.getInsertionBlock();
+    auto opPosition = rewriter.getInsertionPoint();
+    auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
+    mlir::Block *continueBlock;
+    if (ifOp.getNumResults() == 0) {
+      continueBlock = remainingOpsBlock;
+    } else {
+      continueBlock =
+          rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes());
+      rewriter.create<mlir::BranchOp>(loc, remainingOpsBlock);
+    }
+
+    // Move blocks from the "then" region to the region containing 'fir.if',
+    // place it before the continuation block, and branch to it.
+    auto &ifOpRegion = ifOp.thenRegion();
+    auto *ifOpBlock = &ifOpRegion.front();
+    auto *ifOpTerminator = ifOpRegion.back().getTerminator();
+    auto ifOpTerminatorOperands = ifOpTerminator->getOperands();
+    rewriter.setInsertionPointToEnd(&ifOpRegion.back());
+    rewriter.create<mlir::BranchOp>(loc, continueBlock, ifOpTerminatorOperands);
+    rewriter.eraseOp(ifOpTerminator);
+    rewriter.inlineRegionBefore(ifOpRegion, continueBlock);
+
+    // Move blocks from the "else" region (if present) to the region containing
+    // 'fir.if', place it before the continuation block and branch to it. It
+    // will be placed after the "then" region.
+    auto *otherwiseBlock = continueBlock;
+    auto &otherwiseRegion = ifOp.elseRegion();
+    if (!otherwiseRegion.empty()) {
+      otherwiseBlock = &otherwiseRegion.front();
+      auto *otherwiseTerm = otherwiseRegion.back().getTerminator();
+      auto otherwiseTermOperands = otherwiseTerm->getOperands();
+      rewriter.setInsertionPointToEnd(&otherwiseRegion.back());
+      rewriter.create<mlir::BranchOp>(loc, continueBlock,
+                                      otherwiseTermOperands);
+      rewriter.eraseOp(otherwiseTerm);
+      rewriter.inlineRegionBefore(otherwiseRegion, continueBlock);
+    }
+
+    rewriter.setInsertionPointToEnd(condBlock);
+    rewriter.create<mlir::CondBranchOp>(
+        loc, ifOp.condition(), ifOpBlock, llvm::ArrayRef<mlir::Value>(),
+        otherwiseBlock, llvm::ArrayRef<mlir::Value>());
+    rewriter.replaceOp(ifOp, continueBlock->getArguments());
+    return success();
+  }
+};
+
+/// Convert `fir.iterate_while` to control-flow.
+class CfgIterWhileConv : public mlir::OpRewritePattern<fir::IterWhileOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::IterWhileOp whileOp,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto loc = whileOp.getLoc();
+
+    // Start by splitting the block containing the 'fir.iterate_while' into
+    // two parts. The part before will get the init code, the part after will
+    // be the end point.
+    auto *initBlock = rewriter.getInsertionBlock();
+    auto initPosition = rewriter.getInsertionPoint();
+    auto *endBlock = rewriter.splitBlock(initBlock, initPosition);
+
+    // Use the first block of the loop body as the condition block since it is
+    // the block that has the induction variable and loop-carried values as
+    // arguments. Split out all operations from the first block into a new
+    // block. Move all body blocks from the loop body region to the region
+    // containing the loop.
+    auto *conditionBlock = &whileOp.region().front();
+    auto *firstBodyBlock =
+        rewriter.splitBlock(conditionBlock, conditionBlock->begin());
+    auto *lastBodyBlock = &whileOp.region().back();
+    rewriter.inlineRegionBefore(whileOp.region(), endBlock);
+    auto iv = conditionBlock->getArgument(0);
+    auto iterateVar = conditionBlock->getArgument(1);
+
+    // Append the induction variable stepping logic to the last body block and
+    // branch back to the condition block. Loop-carried values are taken from
+    // operands of the loop terminator.
+    auto *terminator = lastBodyBlock->getTerminator();
+    rewriter.setInsertionPointToEnd(lastBodyBlock);
+    auto step = whileOp.step();
+    mlir::Value stepped = rewriter.create<mlir::AddIOp>(loc, iv, step);
+    assert(stepped && "must be a Value");
+
+    llvm::SmallVector<mlir::Value> loopCarried;
+    loopCarried.push_back(stepped);
+    auto begin = whileOp.finalValue() ? std::next(terminator->operand_begin())
+                                      : terminator->operand_begin();
+    loopCarried.append(begin, terminator->operand_end());
+    rewriter.create<mlir::BranchOp>(loc, conditionBlock, loopCarried);
+    rewriter.eraseOp(terminator);
+
+    // Compute loop bounds before branching to the condition.
+    rewriter.setInsertionPointToEnd(initBlock);
+    auto lowerBound = whileOp.lowerBound();
+    auto upperBound = whileOp.upperBound();
+    assert(lowerBound && upperBound && "must be a Value");
+
+    // The initial values of loop-carried values are obtained from the operands
+    // of the loop operation.
+    llvm::SmallVector<mlir::Value> destOperands;
+    destOperands.push_back(lowerBound);
+    auto iterOperands = whileOp.getIterOperands();
+    destOperands.append(iterOperands.begin(), iterOperands.end());
+    rewriter.create<mlir::BranchOp>(loc, conditionBlock, destOperands);
+
+    // With the body block done, we can fill in the condition block.
+    rewriter.setInsertionPointToEnd(conditionBlock);
+    // The comparison depends on the sign of the step value. We fully expect
+    // this expression to be folded by the optimizer or LLVM. This expression
+    // is written this way so that `step == 0` always returns `false`.
+    auto zero = rewriter.create<mlir::ConstantIndexOp>(loc, 0);
+    auto compl0 =
+        rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt, zero, step);
+    auto compl1 =
+        rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sle, iv, upperBound);
+    auto compl2 =
+        rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt, step, zero);
+    auto compl3 =
+        rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sle, upperBound, iv);
+    auto cmp0 = rewriter.create<mlir::AndOp>(loc, compl0, compl1);
+    auto cmp1 = rewriter.create<mlir::AndOp>(loc, compl2, compl3);
+    auto cmp2 = rewriter.create<mlir::OrOp>(loc, cmp0, cmp1);
+    // Remember to AND in the early-exit bool.
+    auto comparison = rewriter.create<mlir::AndOp>(loc, iterateVar, cmp2);
+    rewriter.create<mlir::CondBranchOp>(loc, comparison, firstBodyBlock,
+                                        llvm::ArrayRef<mlir::Value>(), endBlock,
+                                        llvm::ArrayRef<mlir::Value>());
+    // The loop results are the condition block arguments, dropping the
+    // induction variable unless finalValue() is set.
+    auto args = whileOp.finalValue()
+                    ? conditionBlock->getArguments()
+                    : conditionBlock->getArguments().drop_front();
+    rewriter.replaceOp(whileOp, args);
+    return success();
+  }
+};
+
+/// Convert FIR structured control flow ops to CFG ops.
+class CfgConversion : public CFGConversionBase<CfgConversion> {
+public:
+  void runOnFunction() override {
+    if (disableCfgConversion)
+      return;
+
+    auto *context = &getContext();
+    mlir::OwningRewritePatternList patterns;
+    patterns.insert<CfgLoopConv, CfgIfConv, CfgIterWhileConv>(context);
+    mlir::ConversionTarget target(*context);
+    target.addLegalDialect<mlir::AffineDialect, FIROpsDialect,
+                           mlir::StandardOpsDialect>();
+
+    // apply the patterns: the FIR structured control ops must be rewritten
+    target.addIllegalOp<ResultOp, DoLoopOp, IfOp, IterWhileOp>();
+    target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
+    if (mlir::failed(mlir::applyPartialConversion(getFunction(), target,
+                                                  std::move(patterns)))) {
+      mlir::emitError(mlir::UnknownLoc::get(context),
+                      "error in converting to CFG\n");
+      signalPassFailure();
+    }
+  }
+};
+} // namespace
+
+bool fir::isAlwaysExecuteLoopBody() { return forceLoopToExecuteOnce; }
+/// Convert FIR's structured control flow ops to CFG ops. This
+/// conversion enables the `createLowerToCFGPass` to transform these to CFG
+/// form.
+std::unique_ptr<mlir::Pass> fir::createFirToCfgPass() {
+  return std::make_unique<CfgConversion>();
+}
Index: flang/test/Fir/loop01.fir
===================================================================
--- /dev/null
+++ flang/test/Fir/loop01.fir
@@ -0,0 +1,181 @@
+// RUN: tco %s | FileCheck %s
+
+// CHECK-LABEL: @x
+func @x(%lb : index, %ub : index, %step : index, %b : i1, %addr : !fir.ref<index>) {
+  // CHECK: [[LOOP:[0-9]+]]:
+  // CHECK: %[[COND:.*]] = icmp sgt i64 %{{.*}}, 0
+  // CHECK: br i1 %[[COND]]
+  fir.do_loop %iv = %lb to %ub step %step unordered {
+    // expect following conditional blocks to get fused
+    // CHECK: select i1 %
+    fir.if %b {
+      // CHECK: store i64
+      fir.store %iv to %addr : !fir.ref<index>
+    } else {
+      %zero = constant 0 : index
+      fir.store %zero to %addr : !fir.ref<index>
+    }
+    // CHECK: br label %[[LOOP]]
+  }
+  // CHECK: ret void
+  return
+}
+
+func private @f2() -> i1
+
+// CHECK-LABEL: @x2
+func @x2(%lo : index, %up : index, %ok : i1) {
+  %c1 = constant 1 : index
+  // CHECK-DAG: %[[count:.*]] = phi i64
+  // CHECK-DAG: %[[exit:.*]] = phi i1
+  // CHECK: %[[cond:.*]] = icmp sle i64 %[[count]], %
+  // CHECK: %[[and:.*]] = and i1 %[[exit]], %[[cond]]
+  // CHECK: br i1 %[[and]]
+  %unused = fir.iterate_while (%i = %lo to %up step %c1) and (%ok1 = %ok) {
+    %ok2 = fir.call @f2() : () -> i1
+    fir.result %ok2 : i1
+  }
+  // CHECK: ret
+  return
+}
+
+func private @f3(i16)
+
+// do_loop with an extra loop-carried value
+// CHECK-LABEL: @x3
+func @x3(%lo : index, %up : index) -> i1 {
+  %c1 = constant 1 : index
+  %ok1 = constant true
+  // CHECK-DAG: %[[ok:.*]] = phi i1
+  // CHECK-DAG: %[[count:.*]] = phi i64
+  // CHECK: = icmp sgt i64 %[[count]], 0
+  %ok2 = fir.do_loop %i = %lo to %up step %c1 iter_args(%j = %ok1) -> i1 {
+    %ok = fir.call @f2() : () -> i1
+    fir.result %ok : i1
+    // CHECK: = sub i64 %[[count]], 1
+  }
+  // CHECK: ret i1 %[[ok]]
+  return %ok2 : i1
+}
+
+// iterate_while with an extra loop-carried value
+// CHECK-LABEL: @y3
+func @y3(%lo : index, %up : index) -> i1 {
+  %c1 = constant 1 : index
+  %ok1 = constant true
+  // CHECK: %[[ok4:.*]] = call i1 @f2()
+  %ok4 = fir.call @f2() : () -> i1
+  // CHECK-DAG: %[[count:.*]] = phi i64
+  // CHECK-DAG: %[[ok3:.*]] = phi i1 {{.*}}[ true
+  // CHECK-DAG: %[[j:.*]] = phi i1 {{.*}}[ %[[ok4]]
+  // CHECK: %[[prev:.*]] = icmp sle i64 %[[count]],
+  // CHECK: = and i1 %[[ok3]], %[[prev]]
+  %ok2:2 = fir.iterate_while (%i = %lo to %up step %c1) and (%ok3 = %ok1) iter_args(%j = %ok4) -> i1 {
+    %ok = fir.call @f2() : () -> i1
+    fir.result %ok3, %ok : i1, i1
+    // CHECK: = add i64 %[[count]], 1
+  }
+  // CHECK: %[[result:.*]] = and i1 %[[ok3]], %[[j]]
+  %andok = and %ok2#0, %ok2#1 : i1
+  // CHECK: ret i1 %[[result]]
+  return %andok : i1
+}
+
+func private @f4(i32) -> i1
+
+// do_loop that returns the final value of the induction
+// CHECK-LABEL: @x4
+// CHECK-SAME: (i64 %[[lo:.*]],
+func @x4(%lo : index, %up : index) -> index {
+  %c1 = constant 1 : index
+  // CHECK: %[[top:.*]] = add i64
+  // CHECK-DAG: %[[i:.*]] = phi i64 {{.*}}[ %[[lo]],
+  // CHECK-DAG: %[[count:.*]] = phi i64 {{.*}}[ %[[top]],
+  // CHECK: icmp sgt i64 %[[count]],
+  %v = fir.do_loop %i = %lo to %up step %c1 -> index {
+    // CHECK: trunc i64 %[[i]] to i32
+    %i1 = fir.convert %i : (index) -> i32
+    // CHECK: call i1 @f4
+    %ok = fir.call @f4(%i1) : (i32) -> i1
+    fir.result %i : index
+  }
+  // CHECK: ret i64 %[[i]]
+  return %v : index
+}
+
+// iterate_while that returns the final value of both inductions
+// CHECK-LABEL: @y4
+func @y4(%lo : index, %up : index) -> index {
+  %c1 = constant 1 : index
+  %ok1 = constant true
+  // CHECK-DAG: %[[i:.*]] = phi i64 [
+  // CHECK-DAG: %[[ok2:.*]] = phi i1 [
+  // CHECK: icmp sle i64 %[[i]]
+  // CHECK: and i1
+  %v:2 = fir.iterate_while (%i = %lo to %up step %c1) and (%ok2 = %ok1) -> (index, i1) {
+    %i1 = fir.convert %i : (index) -> i32
+    // CHECK: call i1 @f4
+    %ok = fir.call @f4(%i1) : (i32) -> i1
+    fir.result %i, %ok : index, i1
+  }
+  // CHECK: ret i64 %[[i]]
+  return %v#0 : index
+}
+
+// do_loop that returns the final induction value
+// and an extra loop-carried value
+// CHECK-LABEL: @x5
+// CHECK-SAME: (i64 %[[lo:.*]],
+func @x5(%lo : index, %up : index) -> index {
+  %c1 = constant 1 : index
+  // CHECK: %[[top:.*]] = add i64
+  %s1 = constant 42 : i16
+  // CHECK-DAG: %[[i:.*]] = phi i64 {{.*}}[ %[[lo]],
+  // CHECK-DAG: %[[count:.*]] = phi i64 {{.*}}[ %[[top]],
+  // CHECK-DAG: %[[s:.*]] = phi i16
+  // CHECK: icmp sgt i64 %[[count]]
+  %v:2 = fir.do_loop %i = %lo to %up step %c1 iter_args(%s = %s1) -> (index, i16) {
+    // CHECK: call i1 @f2
+    %ok = fir.call @f2() : () -> i1
+    %s2 = fir.convert %ok : (i1) -> i16
+    fir.result %i, %s2 : index, i16
+    // CHECK: add i64 %[[i]], 1
+    // CHECK: sub i64 %[[count]], 1
+  }
+  // CHECK: call void @f3
+  fir.call @f3(%v#1) : (i16) -> ()
+  // CHECK: ret i64 %[[i]]
+  return %v#0 : index
+}
+
+// iterate_while that returns both induction values
+// and an extra loop-carried value
+// CHECK-LABEL: @y5
+// CHECK-SAME: (i64 %[[lo:.*]],
+func @y5(%lo : index, %up : index) -> index {
+  %c1 = constant 1 : index
+  %s1 = constant 42 : i16
+  %ok1 = constant true
+  // CHECK-DAG: %[[i:.*]] = phi i64 {{.*}}[ %[[lo]],
+  // CHECK-DAG: %[[ok2:.*]] = phi i1 {{.*}}[ true,
+  // CHECK-DAG: %[[s:.*]] = phi i16 {{.*}}[ 42,
+  // CHECK: icmp sle i64 %[[i]]
+  // CHECK: and i1 %[[ok2]]
+  %v:3 = fir.iterate_while (%i = %lo to %up step %c1) and (%ok2 = %ok1) iter_args(%s = %s1) -> (index, i1, i16) {
+    // CHECK: call i1 @f2
+    %ok = fir.call @f2() : () -> i1
+    %s2 = fir.convert %ok : (i1) -> i16
+    fir.result %i, %ok, %s2 : index, i1, i16
+    // CHECK: add i64 %[[i]], 1
+  }
+  // CHECK: br i1 %[[ok2]],
+  fir.if %v#1 {
+    %arg = constant 0 : i32
+    // CHECK: call i1 @f4
+    %ok4 = fir.call @f4(%arg) : (i32) -> i1
+  }
+  // CHECK: call void @f3(i16 %[[s]])
+  fir.call @f3(%v#2) : (i16) -> ()
+  // CHECK: ret i64 %[[i]]
+  return %v#0 : index
+}
Index: flang/test/Fir/loop02.fir
===================================================================
--- /dev/null
+++ flang/test/Fir/loop02.fir
@@ -0,0 +1,17 @@
+// RUN: tco --always-execute-loop-body %s | FileCheck %s
+
+// CHECK-LABEL: @x
+func @x(%addr : !fir.ref<index>) {
+  %bound = constant 452 : index
+  %step = constant 1 : index
+  // CHECK: %[[phi:.*]] = phi i64 [{{.*}}], [ 1,
+  // CHECK: = icmp sgt i64 %[[phi]], 0
+  fir.do_loop %iv = %bound to %bound step %step {
+    // CHECK: call void @y(i64* %
+    fir.call @y(%addr) : (!fir.ref<index>) -> ()
+  }
+  // CHECK: ret void
+  return
+}
+
+func private @y(%addr : !fir.ref<index>)
Index: flang/tools/tco/tco.cpp
===================================================================
--- flang/tools/tco/tco.cpp
+++ flang/tools/tco/tco.cpp
@@ -113,6 +113,7 @@
   pm.addPass(mlir::createCSEPass());
 
   // convert control flow to CFG form
+  pm.addNestedPass<mlir::FuncOp>(fir::createFirToCfgPass());
   pm.addPass(mlir::createLowerToCFGPass());
   pm.addPass(mlir::createCanonicalizerPass());
 