diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -23,6 +23,7 @@
 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
 #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
+#include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"
 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
 #include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h"
 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -230,6 +230,17 @@
   let dependentDialects = ["pdl_interp::PDLInterpDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// SCFToOpenMP
+//===----------------------------------------------------------------------===//
+
+def ConvertSCFToOpenMP : FunctionPass<"convert-scf-to-openmp"> {
+  let summary = "Convert SCF parallel loop to OpenMP parallel + workshare "
+                "constructs.";
+  let constructor = "mlir::createConvertSCFToOpenMPPass()";
+  let dependentDialects = ["omp::OpenMPDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // SCFToStandard
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h b/mlir/include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h
@@ -0,0 +1,23 @@
+//===- SCFToOpenMP.h - SCF to OpenMP pass entrypoint ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_SCFTOOPENMP_SCFTOOPENMP_H
+#define MLIR_CONVERSION_SCFTOOPENMP_SCFTOOPENMP_H
+
+#include <memory>
+
+namespace mlir {
+class FuncOp;
+template <typename T>
+class OperationPass;
+
+std::unique_ptr<OperationPass<FuncOp>> createConvertSCFToOpenMPPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_SCFTOOPENMP_SCFTOOPENMP_H
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -92,6 +92,9 @@
 
   let regions = (region AnyRegion:$region);
 
+  let builders = [
+    OpBuilderDAG<(ins CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
+  ];
   let parser = [{ return parseParallelOp(parser, result); }];
   let printer = [{ return printParallelOp(p, *this); }];
   let verifier = [{ return ::verifyParallelOp(*this); }];
@@ -175,6 +178,12 @@
              Confined<OptionalAttr<I64Attr>, [IntMinValue<0>]>:$ordered_val,
              OptionalAttr<OrderKind>:$order_val);
 
+  let builders = [
+    OpBuilderDAG<(ins "ValueRange":$lowerBound, "ValueRange":$upperBound,
+                      "ValueRange":$step,
+                      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
+  ];
+
   let regions = (region AnyRegion:$region);
 }
 
diff --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h
--- a/mlir/lib/Conversion/PassDetail.h
+++ b/mlir/lib/Conversion/PassDetail.h
@@ -33,6 +33,10 @@
 class NVVMDialect;
 } // end namespace NVVM
 
+namespace omp {
+class OpenMPDialect;
+} // end namespace omp
+
 namespace pdl_interp {
 class PDLInterpDialect;
 } // end namespace pdl_interp
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -0,0 +1,131 @@
+//===- SCFToOpenMP.cpp - Structured Control Flow to OpenMP conversion -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert scf.parallel operations into OpenMP
+// parallel loops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"
+#include "../PassDetail.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Rewrites an scf.parallel op that has no results into an omp.wsloop op with
+/// the same lower bounds, upper bounds and steps. The loop body is moved into
+/// the new op and its scf.yield terminator is replaced by omp.yield.
+struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
+  using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
+                                PatternRewriter &rewriter) const override {
+    // TODO: add support for reductions when OpenMP loops have them.
+    if (parallelOp.getNumResults() != 0)
+      return rewriter.notifyMatchFailure(
+          parallelOp,
+          "OpenMP dialect does not yet support loops with reductions");
+
+    // Replace SCF yield with OpenMP yield. The single-block assumption is
+    // checked before getBody() is used.
+    assert(llvm::hasSingleElement(parallelOp.region()) &&
+           "expected scf.parallel to have one block");
+    {
+      // Build the omp.yield at the end of the loop body, next to the
+      // scf.yield it replaces; the guard restores the insertion point.
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToEnd(parallelOp.getBody());
+      rewriter.replaceOpWithNewOp<omp::YieldOp>(
+          parallelOp.getBody()->getTerminator(), ValueRange());
+    }
+
+    // Replace the loop: create an omp.wsloop with the same bounds and steps,
+    // move the scf.parallel body region into it, then erase the old op.
+    auto loop = rewriter.create<omp::WsLoopOp>(
+        parallelOp.getLoc(), parallelOp.lowerBound(), parallelOp.upperBound(),
+        parallelOp.step());
+    rewriter.inlineRegionBefore(parallelOp.region(), loop.region(),
+                                loop.region().begin());
+    rewriter.eraseOp(parallelOp);
+    return success();
+  }
+};
+
+/// Function pass that wraps scf.parallel ops into omp.parallel ops and then
+/// lowers them to omp.wsloop ops.
+struct SCFToOpenMPPass : public ConvertSCFToOpenMPBase<SCFToOpenMPPass> {
+  /// Wraps every scf.parallel op in `func` that has no results and is not
+  /// already nested in an omp.parallel op into a new omp.parallel op, so the
+  /// omp.wsloop it is later lowered to executes inside a parallel region.
+  void insertOpenMPParallel(FuncOp func) {
+    SmallVector<scf::ParallelOp, 4> opsToWrap;
+    func.walk([&opsToWrap](scf::ParallelOp parallelOp) {
+      // Ignore ops that are already within an OpenMP parallel construct.
+      if (parallelOp.getParentOfType<omp::ParallelOp>())
+        return;
+      // Loops with results (reductions) are not lowered by ParallelOpLowering.
+      // Wrapping them would move their results into the omp.parallel region
+      // while their users stay outside it, leaving invalid IR behind when the
+      // conversion then fails.
+      if (parallelOp.getNumResults() != 0)
+        return;
+      opsToWrap.push_back(parallelOp);
+    });
+
+    for (scf::ParallelOp parallelOp : opsToWrap) {
+      OpBuilder builder(parallelOp);
+      auto ompParallel = builder.create<omp::ParallelOp>(parallelOp.getLoc());
+      Block *block = builder.createBlock(&ompParallel.getRegion());
+      builder.create<omp::TerminatorOp>(parallelOp.getLoc());
+      // Move the scf.parallel op to the front of the new block, ahead of the
+      // omp.terminator created above.
+      block->getOperations().splice(
+          block->begin(),
+          parallelOp.getOperation()->getBlock()->getOperations(),
+          parallelOp.getOperation());
+    }
+  }
+
+  /// Converts the scf.parallel ops in `func`, and the scf.yield ops that
+  /// terminate their bodies, to OpenMP dialect ops. scf.parallel is marked
+  /// illegal, so the partial conversion fails if any of them remains.
+  LogicalResult applyPatterns(FuncOp func) {
+    MLIRContext *context = func.getContext();
+    ConversionTarget target(*context);
+    target.addIllegalOp<scf::ParallelOp>();
+    // scf.yield stays legal unless it terminates an scf.parallel body.
+    target.addDynamicallyLegalOp<scf::YieldOp>([](scf::YieldOp op) {
+      return !isa<scf::ParallelOp>(op.getParentOp());
+    });
+    target.addLegalDialect<omp::OpenMPDialect>();
+
+    OwningRewritePatternList patterns;
+    patterns.insert<ParallelOpLowering>(context);
+    FrozenRewritePatternList frozen(std::move(patterns));
+    return applyPartialConversion(func, target, frozen);
+  }
+
+  /// Pass entry point: wraps scf.parallel ops into omp.parallel ops, then
+  /// lowers them; signals pass failure if the conversion fails.
+  void runOnFunction() override {
+    FuncOp func = getFunction();
+    insertOpenMPParallel(func);
+    if (failed(applyPatterns(func)))
+      signalPassFailure();
+  }
+};
+
+} // end namespace
+
+std::unique_ptr<OperationPass<FuncOp>> mlir::createConvertSCFToOpenMPPass() {
+  return std::make_unique<SCFToOpenMPPass>();
+}