diff --git a/mlir/include/mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h b/mlir/include/mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h
@@ -0,0 +1,28 @@
+//===- ConvertOpenACCToSCF.h - OpenACC conversion pass entrypoint ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_OPENACCTOSCF_CONVERTOPENACCTOSCF_H
+#define MLIR_CONVERSION_OPENACCTOSCF_CONVERTOPENACCTOSCF_H
+
+#include <memory>
+
+namespace mlir {
+class ModuleOp;
+template <typename T>
+class OperationPass;
+class RewritePatternSet;
+
+/// Collect the patterns to convert from the OpenACC dialect to OpenACC with
+/// SCF dialect.
+void populateOpenACCToSCFConversionPatterns(RewritePatternSet &patterns);
+
+/// Create a pass to convert the OpenACC dialect to OpenACC with SCF dialect.
+std::unique_ptr<OperationPass<ModuleOp>> createConvertOpenACCToSCFPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_OPENACCTOSCF_CONVERTOPENACCTOSCF_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -23,6 +23,7 @@
 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
 #include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h"
+#include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
 #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -255,6 +255,16 @@
   let dependentDialects = ["StandardOpsDialect", "vector::VectorDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// OpenACCToSCF
+//===----------------------------------------------------------------------===//
+
+def ConvertOpenACCToSCF : Pass<"convert-openacc-to-scf", "ModuleOp"> {
+  let summary = "Convert the OpenACC ops to OpenACC with SCF dialect";
+  let constructor = "mlir::createConvertOpenACCToSCFPass()";
+  let dependentDialects = ["scf::SCFDialect", "acc::OpenACCDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // OpenACCToLLVM
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -489,13 +489,13 @@
     ```
   }];
 
-  let arguments = (ins Optional<IntOrIndex>:$asyncOperand,
+  let arguments = (ins Optional<I1>:$ifCond,
+                       Optional<IntOrIndex>:$asyncOperand,
                        Optional<IntOrIndex>:$waitDevnum,
                        Variadic<IntOrIndex>:$waitOperands,
                        UnitAttr:$async,
                        UnitAttr:$wait,
                        Variadic<IntOrIndex>:$deviceTypeOperands,
-                       Optional<I1>:$ifCond,
                        Variadic<AnyType>:$hostOperands,
                        Variadic<AnyType>:$deviceOperands,
                        UnitAttr:$ifPresent);
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -12,6 +12,7 @@
 add_subdirectory(LinalgToStandard)
 add_subdirectory(MathToLibm)
 add_subdirectory(OpenACCToLLVM)
+add_subdirectory(OpenACCToSCF)
 add_subdirectory(OpenMPToLLVM)
 add_subdirectory(PDLToPDLInterp)
 add_subdirectory(SCFToGPU)
diff --git a/mlir/lib/Conversion/OpenACCToSCF/CMakeLists.txt b/mlir/lib/Conversion/OpenACCToSCF/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/OpenACCToSCF/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_mlir_conversion_library(MLIROpenACCToSCF
+  OpenACCToSCF.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/OpenACCToSCF
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIROpenACC
+  MLIRTransforms
+  MLIRSCF
+  )
diff --git a/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp b/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp
@@ -0,0 +1,105 @@
+//===- OpenACCToSCF.cpp - OpenACC condition to SCF if conversion ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "../PassDetail.h"
+#include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Conversion patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Pattern to expand the `ifCond` of an operation without region:
+///  - a dynamic condition becomes an scf.if whose `then` region holds a copy
+///    of the operation without the condition;
+///  - a constant true condition is dropped from the operation;
+///  - a constant false condition erases the operation.
+template <typename OpTy>
+class ExpandIfCondition : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // Nothing to expand: report failure since the IR is left untouched.
+    Value cond = op.ifCond();
+    if (!cond)
+      return failure();
+
+    // Constant condition: resolve it statically, no scf.if is needed.
+    ConstantOp constOp = cond.getDefiningOp<ConstantOp>();
+    if (BoolAttr constCond =
+            constOp ? constOp.getValue().dyn_cast<BoolAttr>() : BoolAttr()) {
+      if (constCond.getValue())
+        rewriter.updateRootInPlace(op, [&]() { op.ifCondMutable().erase(0); });
+      else
+        rewriter.eraseOp(op);
+      return success();
+    }
+
+    // Dynamic condition: clone the operation, minus its condition, into the
+    // `then` region of an scf.if. Cloning through the rewriter (rather than a
+    // standalone builder) notifies the conversion driver of the new operation.
+    auto ifOp =
+        rewriter.create<scf::IfOp>(op.getLoc(), TypeRange(), cond, false);
+    rewriter.updateRootInPlace(op, [&]() { op.ifCondMutable().erase(0); });
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(ifOp.thenRegion().front().getTerminator());
+    rewriter.clone(*op.getOperation());
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateOpenACCToSCFConversionPatterns(RewritePatternSet &patterns) {
+  patterns.add<ExpandIfCondition<acc::EnterDataOp>>(patterns.getContext());
+  patterns.add<ExpandIfCondition<acc::ExitDataOp>>(patterns.getContext());
+  patterns.add<ExpandIfCondition<acc::UpdateOp>>(patterns.getContext());
+}
+
+namespace {
+struct ConvertOpenACCToSCFPass
+    : public ConvertOpenACCToSCFBase<ConvertOpenACCToSCFPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void ConvertOpenACCToSCFPass::runOnOperation() {
+  auto op = getOperation();
+  auto *context = op.getContext();
+
+  RewritePatternSet patterns(context);
+  ConversionTarget target(*context);
+  populateOpenACCToSCFConversionPatterns(patterns);
+
+  target.addLegalDialect<scf::SCFDialect>();
+  target.addLegalDialect<acc::OpenACCDialect>();
+
+  // An operation with an `ifCond` is illegal until the pattern expands it.
+  target.addDynamicallyLegalOp<acc::EnterDataOp>(
+      [](acc::EnterDataOp op) { return !op.ifCond(); });
+
+  target.addDynamicallyLegalOp<acc::ExitDataOp>(
+      [](acc::ExitDataOp op) { return !op.ifCond(); });
+
+  target.addDynamicallyLegalOp<acc::UpdateOp>(
+      [](acc::UpdateOp op) { return !op.ifCond(); });
+
+  if (failed(applyPartialConversion(op, target, std::move(patterns))))
+    signalPassFailure();
+}
+
+std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertOpenACCToSCFPass() {
+  return std::make_unique<ConvertOpenACCToSCFPass>();
+}
diff --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h
--- a/mlir/lib/Conversion/PassDetail.h
+++ b/mlir/lib/Conversion/PassDetail.h
@@ -19,6 +19,10 @@
 template <typename ConcreteDialect>
 void registerDialect(DialectRegistry &registry);
 
+namespace acc {
+class OpenACCDialect;
+} // end namespace acc
+
 namespace complex {
 class ComplexDialect;
 } // end namespace complex
diff --git a/mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir b/mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir
@@ -0,0 +1,59 @@
+// RUN: mlir-opt %s -convert-openacc-to-scf -split-input-file | FileCheck %s
+
+func @testenterdataop(%a: memref<10xf32>, %ifCond: i1) -> () {
+  acc.enter_data if(%ifCond) create(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: func @testenterdataop(%{{.*}}: memref<10xf32>, [[IFCOND:%.*]]: i1)
+// CHECK:   scf.if [[IFCOND]] {
+// CHECK-NEXT:     acc.enter_data create(%{{.*}} : memref<10xf32>)
+// CHECK-NEXT:   }
+
+// -----
+
+func @testexitdataop(%a: memref<10xf32>, %ifCond: i1) -> () {
+  acc.exit_data if(%ifCond) delete(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: func @testexitdataop(%{{.*}}: memref<10xf32>, [[IFCOND:%.*]]: i1)
+// CHECK:   scf.if [[IFCOND]] {
+// CHECK-NEXT:     acc.exit_data delete(%{{.*}} : memref<10xf32>)
+// CHECK-NEXT:   }
+
+// -----
+
+func @testupdateop(%a: memref<10xf32>, %ifCond: i1) -> () {
+  acc.update if(%ifCond) host(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: func @testupdateop(%{{.*}}: memref<10xf32>, [[IFCOND:%.*]]: i1)
+// CHECK:   scf.if [[IFCOND]] {
+// CHECK-NEXT:     acc.update host(%{{.*}} : memref<10xf32>)
+// CHECK-NEXT:   }
+
+// -----
+
+func @testifcondtrue(%a: memref<10xf32>) -> () {
+  %ifCond = constant true
+  acc.update if(%ifCond) host(%a: memref<10xf32>)
+  return
+}
+
+// CHECK-LABEL: func @testifcondtrue
+// CHECK-NOT:     scf.if
+// CHECK:         acc.update host(%{{.*}} : memref<10xf32>)
+
+// -----
+
+func @testifcondfalse(%a: memref<10xf32>) -> () {
+  %ifCond = constant false
+  acc.update if(%ifCond) host(%a: memref<10xf32>)
+  return
+}
+
+// CHECK-LABEL: func @testifcondfalse
+// CHECK-NOT:     acc.update
+// CHECK:         return