diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -23,6 +23,7 @@
 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
 #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
+#include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"
 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
 #include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h"
 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -230,6 +230,17 @@
   let dependentDialects = ["pdl_interp::PDLInterpDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// SCFToOpenMP
+//===----------------------------------------------------------------------===//
+
+def ConvertSCFToOpenMP : FunctionPass<"convert-scf-to-openmp"> {
+  let summary = "Convert SCF parallel loop to OpenMP parallel + workshare "
+                "constructs.";
+  let constructor = "mlir::createConvertSCFToOpenMPPass()";
+  let dependentDialects = ["omp::OpenMPDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // SCFToStandard
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h b/mlir/include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h
@@ -0,0 +1,23 @@
+//===- SCFToOpenMP.h - SCF to OpenMP pass entrypoint ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_SCFTOOPENMP_SCFTOOPENMP_H
+#define MLIR_CONVERSION_SCFTOOPENMP_SCFTOOPENMP_H
+
+#include <memory>
+
+namespace mlir {
+class FuncOp;
+template <typename T>
+class OperationPass;
+
+std::unique_ptr<OperationPass<FuncOp>> createConvertSCFToOpenMPPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_SCFTOOPENMP_SCFTOOPENMP_H
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -92,6 +92,9 @@
 
   let regions = (region AnyRegion:$region);
 
+  let builders = [
+    OpBuilderDAG<(ins CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
+  ];
   let parser = [{ return parseParallelOp(parser, result); }];
   let printer = [{ return printParallelOp(p, *this); }];
   let verifier = [{ return ::verifyParallelOp(*this); }];
@@ -175,6 +178,12 @@
              Confined<OptionalAttr<I64Attr>, [IntMinValue<0>]>:$ordered_val,
              OptionalAttr<OrderKind>:$order_val);
 
+  let builders = [
+    OpBuilderDAG<(ins "ValueRange":$lowerBound, "ValueRange":$upperBound,
+                  "ValueRange":$step,
+                  CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
+  ];
+
   let regions = (region AnyRegion:$region);
 }
 
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -12,6 +12,7 @@
 add_subdirectory(OpenMPToLLVM)
 add_subdirectory(PDLToPDLInterp)
 add_subdirectory(SCFToGPU)
+add_subdirectory(SCFToOpenMP)
 add_subdirectory(SCFToSPIRV)
 add_subdirectory(SCFToStandard)
 add_subdirectory(ShapeToStandard)
diff --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h
--- a/mlir/lib/Conversion/PassDetail.h
+++ b/mlir/lib/Conversion/PassDetail.h
@@ -33,6 +33,10 @@
 class NVVMDialect;
 } // end namespace NVVM
 
+namespace omp {
+class OpenMPDialect;
+} // end namespace omp
+
 namespace pdl_interp {
 class PDLInterpDialect;
 } // end namespace pdl_interp
diff --git a/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt b/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_conversion_library(MLIRSCFToOpenMP
+  SCFToOpenMP.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/SCFToOpenMP
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIROpenMP
+  MLIRSCF
+  MLIRTransforms
+  )
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -0,0 +1,111 @@
+//===- SCFToOpenMP.cpp - Structured Control Flow to OpenMP conversion -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert scf.parallel operations into OpenMP
+// parallel loops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"
+#include "../PassDetail.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Converts SCF parallel operation into an OpenMP workshare loop construct.
+struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
+  using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
+                                PatternRewriter &rewriter) const override {
+    // TODO: add support for reductions when OpenMP loops have them.
+    if (parallelOp.getNumResults() != 0)
+      return rewriter.notifyMatchFailure(
+          parallelOp,
+          "OpenMP dialect does not yet support loops with reductions");
+
+    // Replace SCF yield with OpenMP yield.
+    {
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToEnd(parallelOp.getBody());
+      assert(llvm::hasSingleElement(parallelOp.region()) &&
+             "expected scf.parallel to have one block");
+      rewriter.replaceOpWithNewOp<omp::YieldOp>(
+          parallelOp.getBody()->getTerminator(), ValueRange());
+    }
+
+    // Replace the loop.
+    auto loop = rewriter.create<omp::WsLoopOp>(
+        parallelOp.getLoc(), parallelOp.lowerBound(), parallelOp.upperBound(),
+        parallelOp.step());
+    rewriter.inlineRegionBefore(parallelOp.region(), loop.region(),
+                                loop.region().begin());
+    rewriter.eraseOp(parallelOp);
+    return success();
+  }
+};
+
+/// Inserts OpenMP "parallel" operations around top-level SCF "parallel"
+/// operations in the given function. This is implemented as a direct IR
+/// modification rather than as a conversion pattern because it does not
+/// modify the top-level operation it matches, which is a requirement for
+/// rewrite patterns.
+static void insertOpenMPParallel(FuncOp func) {
+  // Collect top-level SCF "parallel" ops.
+  SmallVector<scf::ParallelOp, 4> topLevelParallelOps;
+  func.walk([&topLevelParallelOps](scf::ParallelOp parallelOp) {
+    // Ignore ops that are already within OpenMP parallel construct.
+    if (!parallelOp.getParentOfType<scf::ParallelOp>())
+      topLevelParallelOps.push_back(parallelOp);
+  });
+
+  // Wrap SCF ops into OpenMP "parallel" ops.
+  for (scf::ParallelOp parallelOp : topLevelParallelOps) {
+    OpBuilder builder(parallelOp);
+    auto omp = builder.create<omp::ParallelOp>(parallelOp.getLoc());
+    Block *block = builder.createBlock(&omp.getRegion());
+    builder.create<omp::TerminatorOp>(parallelOp.getLoc());
+    block->getOperations().splice(
+        block->begin(), parallelOp.getOperation()->getBlock()->getOperations(),
+        parallelOp.getOperation());
+  }
+}
+
+/// Applies the conversion patterns in the given function.
+static LogicalResult applyPatterns(FuncOp func) {
+  ConversionTarget target(*func.getContext());
+  target.addIllegalOp<scf::ParallelOp>();
+  target.addDynamicallyLegalOp<scf::YieldOp>(
+      [](scf::YieldOp op) { return !isa<scf::ParallelOp>(op.getParentOp()); });
+  target.addLegalDialect<omp::OpenMPDialect>();
+
+  OwningRewritePatternList patterns;
+  patterns.insert<ParallelOpLowering>(func.getContext());
+  FrozenRewritePatternList frozen(std::move(patterns));
+  return applyPartialConversion(func, target, frozen);
+}
+
+/// A pass converting SCF operations to OpenMP operations.
+struct SCFToOpenMPPass : public ConvertSCFToOpenMPBase<SCFToOpenMPPass> {
+  /// Pass entry point.
+  void runOnFunction() override {
+    insertOpenMPParallel(getFunction());
+    if (failed(applyPatterns(getFunction())))
+      signalPassFailure();
+  }
+};
+
+} // end namespace
+
+std::unique_ptr<OperationPass<FuncOp>> mlir::createConvertSCFToOpenMPPass() {
+  return std::make_unique<SCFToOpenMPPass>();
+}
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -37,6 +37,17 @@
 // ParallelOp
 //===----------------------------------------------------------------------===//
 
+void ParallelOp::build(OpBuilder &builder, OperationState &state,
+                       ArrayRef<NamedAttribute> attributes) {
+  ParallelOp::build(
+      builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
+      /*default_val=*/nullptr, /*private_vars=*/ValueRange(),
+      /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(),
+      /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(),
+      /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr);
+  state.addAttributes(attributes);
+}
+
 /// Parse a list of operands with types.
 ///
 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)`
@@ -362,5 +373,22 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// WsLoopOp
+//===----------------------------------------------------------------------===//
+
+void WsLoopOp::build(OpBuilder &builder, OperationState &state,
+                     ValueRange lowerBound, ValueRange upperBound,
+                     ValueRange step, ArrayRef<NamedAttribute> attributes) {
+  build(builder, state, TypeRange(), lowerBound, upperBound, step,
+        /*private_vars=*/ValueRange(),
+        /*firstprivate_vars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(),
+        /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
+        /*schedule_val=*/nullptr, /*schedule_chunk_var=*/nullptr,
+        /*collapse_val=*/nullptr,
+        /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr);
+  state.addAttributes(attributes);
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
diff --git a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
@@ -0,0 +1,65 @@
+// RUN: mlir-opt -convert-scf-to-openmp %s | FileCheck %s
+
+// CHECK-LABEL: @parallel
+func @parallel(%arg0: index, %arg1: index, %arg2: index,
+          %arg3: index, %arg4: index, %arg5: index) {
+  // CHECK: omp.parallel {
+  // CHECK: "omp.wsloop"({{.*}}) ( {
+  scf.parallel (%i, %j) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
+    // CHECK: test.payload
+    "test.payload"(%i, %j) : (index, index) -> ()
+    // CHECK: omp.yield
+    // CHECK: }
+  }
+  // CHECK: omp.terminator
+  // CHECK: }
+  return
+}
+
+// CHECK-LABEL: @nested_loops
+func @nested_loops(%arg0: index, %arg1: index, %arg2: index,
+                   %arg3: index, %arg4: index, %arg5: index) {
+  // CHECK: omp.parallel {
+  // CHECK: "omp.wsloop"({{.*}}) ( {
+  // CHECK-NOT: omp.parallel
+  scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
+    // CHECK: "omp.wsloop"({{.*}}) ( {
+    scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) {
+      // CHECK: test.payload
+      "test.payload"(%i, %j) : (index, index) -> ()
+      // CHECK: omp.yield
+      // CHECK: }
+    }
+    // CHECK: omp.yield
+    // CHECK: }
+  }
+  // CHECK: omp.terminator
+  // CHECK: }
+  return
+}
+
+func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
+                     %arg3: index, %arg4: index, %arg5: index) {
+  // CHECK: omp.parallel {
+  // CHECK: "omp.wsloop"({{.*}}) ( {
+  scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
+    // CHECK: test.payload1
+    "test.payload1"(%i) : (index) -> ()
+    // CHECK: omp.yield
+    // CHECK: }
+  }
+  // CHECK: omp.terminator
+  // CHECK: }
+
+  // CHECK: omp.parallel {
+  // CHECK: "omp.wsloop"({{.*}}) ( {
+  scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) {
+    // CHECK: test.payload2
+    "test.payload2"(%j) : (index) -> ()
+    // CHECK: omp.yield
+    // CHECK: }
+  }
+  // CHECK: omp.terminator
+  // CHECK: }
+  return
+}