diff --git a/mlir/include/mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h b/mlir/include/mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h
@@ -0,0 +1,28 @@
+//===- ConvertOpenACCToSCF.h - OpenACC conversion pass entrypoint ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_OPENACCTOSCF_CONVERTOPENACCTOSCF_H
+#define MLIR_CONVERSION_OPENACCTOSCF_CONVERTOPENACCTOSCF_H
+
+#include <memory>
+
+namespace mlir {
+class ModuleOp;
+template <typename T>
+class OperationPass;
+class RewritePatternSet;
+
+/// Collect the patterns that expand the `if` condition of the OpenACC
+/// standalone data operations (enter_data, exit_data, update) into SCF.
+void populateOpenACCToSCFConversionPatterns(RewritePatternSet &patterns);
+
+/// Create a pass to convert the OpenACC dialect to OpenACC with SCF dialect.
+std::unique_ptr<OperationPass<ModuleOp>> createConvertOpenACCToSCFPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_OPENACCTOSCF_CONVERTOPENACCTOSCF_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -23,6 +23,7 @@
 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
 #include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h"
+#include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
 #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -255,6 +255,16 @@
   let dependentDialects = ["StandardOpsDialect", "vector::VectorDialect"];
 }

+//===----------------------------------------------------------------------===//
+// OpenACCToSCF
+//===----------------------------------------------------------------------===//
+
+def ConvertOpenACCToSCF : Pass<"convert-openacc-to-scf", "ModuleOp"> {
+  let summary = "Convert the OpenACC ops to OpenACC with SCF dialect";
+  let constructor = "mlir::createConvertOpenACCToSCFPass()";
+  let dependentDialects = ["scf::SCFDialect", "acc::OpenACCDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // OpenACCToLLVM
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -485,13 +485,13 @@
     ```
   }];

-  let arguments = (ins Optional:$asyncOperand,
+  let arguments = (ins Optional:$ifCond,
+                       Optional:$asyncOperand,
                        Optional:$waitDevnum,
                        Variadic:$waitOperands,
                        UnitAttr:$async,
                        UnitAttr:$wait,
                        Variadic:$deviceTypeOperands,
-                       Optional:$ifCond,
                        Variadic:$hostOperands,
                        Variadic:$deviceOperands,
                        UnitAttr:$ifPresent);
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -12,6 +12,7 @@
 add_subdirectory(LinalgToStandard)
 add_subdirectory(MathToLibm)
 add_subdirectory(OpenACCToLLVM)
+add_subdirectory(OpenACCToSCF)
 add_subdirectory(OpenMPToLLVM)
 add_subdirectory(PDLToPDLInterp)
 add_subdirectory(SCFToGPU)
diff --git a/mlir/lib/Conversion/OpenACCToSCF/CMakeLists.txt b/mlir/lib/Conversion/OpenACCToSCF/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/OpenACCToSCF/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_mlir_conversion_library(MLIROpenACCToSCF
+  OpenACCToSCF.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/OpenACCToSCF
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIROpenACC
+  MLIRTransforms
+  MLIRSCF
+  )
diff --git a/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp b/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp
@@ -0,0 +1,117 @@
+//===- OpenACCToSCF.cpp - OpenACC condition to SCF if conversion ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "../PassDetail.h"
+#include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+/// Erase the present `if` condition of `op` (operand 0 / first segment for
+/// every standalone data operation handled here) and zero its segment size.
+template <typename OpTy>
+static void removeIfCondition(OpTy &op) {
+  op.getOperation()->eraseOperand(0);
+  auto attrName = op.getOperandSegmentSizeAttr();
+  auto sizeAttr =
+      op.getOperation()->template getAttrOfType<DenseIntElementsAttr>(attrName);
+  SmallVector<int32_t, 8> sizes;
+  sizes.push_back(0); // The ifCond segment is now empty.
+  for (auto size : llvm::enumerate(sizeAttr.getIntValues()))
+    if (size.index() != 0)
+      sizes.push_back(size.value().getSExtValue());
+  op.getOperation()->setAttr(attrName,
+                             Builder(op.getContext()).getI32VectorAttr(sizes));
+}
+
+//===----------------------------------------------------------------------===//
+// Conversion patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Expand the `if` condition of an OpenACC standalone data op: wrap the op in
+/// `scf.if` if it is dynamic, drop it if constant true, erase the op if false.
+template <typename OpTy>
+class ExpandIfCondition : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // No condition: report a non-match so drivers do not assume a rewrite.
+    if (!op.ifCond())
+      return failure();
+
+    IntegerAttr constAttr;
+    if (!matchPattern(op.ifCond(), m_Constant(&constAttr))) {
+      // Non-constant condition: guard a condition-free clone with scf.if.
+      auto ifOp = rewriter.create<scf::IfOp>(op.getLoc(), TypeRange(),
+                                             op.ifCond(), false);
+      // Clone through the rewriter's listener so the driver tracks the new op.
+      OpBuilder thenBodyBuilder = ifOp.getThenBodyBuilder();
+      thenBodyBuilder.setListener(rewriter.getListener());
+      auto newOp = cast<OpTy>(thenBodyBuilder.clone(*op.getOperation()));
+      rewriter.updateRootInPlace(newOp, [&] { removeIfCondition(newOp); });
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    // Constant condition: drop it when true, erase the op when false.
+    if (constAttr.getInt())
+      rewriter.updateRootInPlace(op, [&] { removeIfCondition(op); });
+    else
+      rewriter.eraseOp(op);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateOpenACCToSCFConversionPatterns(RewritePatternSet &patterns) {
+  patterns.add<ExpandIfCondition<acc::EnterDataOp>>(patterns.getContext());
+  patterns.add<ExpandIfCondition<acc::ExitDataOp>>(patterns.getContext());
+  patterns.add<ExpandIfCondition<acc::UpdateOp>>(patterns.getContext());
+}
+
+namespace {
+struct ConvertOpenACCToSCFPass
+    : public ConvertOpenACCToSCFBase<ConvertOpenACCToSCFPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void ConvertOpenACCToSCFPass::runOnOperation() {
+  auto op = getOperation();
+  auto *context = op.getContext();
+
+  RewritePatternSet patterns(context);
+  populateOpenACCToSCFConversionPatterns(patterns);
+
+  // The data ops are legal once they no longer carry an `if` condition.
+  ConversionTarget target(*context);
+  target.addLegalDialect<acc::OpenACCDialect, scf::SCFDialect>();
+  target.addDynamicallyLegalOp<acc::EnterDataOp>(
+      [](acc::EnterDataOp op) { return !op.ifCond(); });
+  target.addDynamicallyLegalOp<acc::ExitDataOp>(
+      [](acc::ExitDataOp op) { return !op.ifCond(); });
+  target.addDynamicallyLegalOp<acc::UpdateOp>(
+      [](acc::UpdateOp op) { return !op.ifCond(); });
+
+  if (failed(applyPartialConversion(op, target, std::move(patterns))))
+    signalPassFailure();
+}
+
+std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertOpenACCToSCFPass() {
+  return std::make_unique<ConvertOpenACCToSCFPass>();
+}
diff --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h
--- a/mlir/lib/Conversion/PassDetail.h
+++ b/mlir/lib/Conversion/PassDetail.h
@@ -19,6 +19,10 @@
 template <typename ConcreteDialect>
 void registerDialect(DialectRegistry &registry);

+namespace acc {
+class OpenACCDialect;
+} // end namespace acc
+
 namespace complex {
 class ComplexDialect;
 } // end namespace complex
diff --git a/mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir b/mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir
@@ -0,0 +1,129 @@
+// RUN: mlir-opt %s -convert-openacc-to-scf -split-input-file | FileCheck %s
+
+func @testenterdataop(%a: memref<10xf32>) -> () {
+  %ifCond = constant true
+  acc.enter_data if(%ifCond) create(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: acc.enter_data create(%{{.*}} : memref<10xf32>)
+
+// -----
+
+func @testenterdataop(%a: memref<10xf32>) -> () {
+  %ifCond = constant false
+  acc.enter_data if(%ifCond) create(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: func @testenterdataop
+// CHECK-NOT: acc.enter_data
+
+// -----
+
+func @testenterdataop(%a: memref<10xf32>, %ifCond: i1) -> () {
+  acc.enter_data if(%ifCond) create(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: func @testenterdataop(%{{.*}}: memref<10xf32>, [[IFCOND:%.*]]: i1)
+// CHECK: scf.if [[IFCOND]] {
+// CHECK-NEXT: acc.enter_data create(%{{.*}} : memref<10xf32>)
+// CHECK-NEXT: }
+
+// -----
+
+func @testexitdataop(%a: memref<10xf32>) -> () {
+  %ifCond = constant true
+  acc.exit_data if(%ifCond) delete(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: acc.exit_data delete(%{{.*}} : memref<10xf32>)
+
+// -----
+
+func @testexitdataop(%a: memref<10xf32>) -> () {
+  %ifCond = constant false
+  acc.exit_data if(%ifCond) delete(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: func @testexitdataop
+// CHECK-NOT: acc.exit_data
+
+// -----
+
+func @testexitdataop(%a: memref<10xf32>, %ifCond: i1) -> () {
+  acc.exit_data if(%ifCond) delete(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: func @testexitdataop(%{{.*}}: memref<10xf32>, [[IFCOND:%.*]]: i1)
+// CHECK: scf.if [[IFCOND]] {
+// CHECK-NEXT: acc.exit_data delete(%{{.*}} : memref<10xf32>)
+// CHECK-NEXT: }
+
+// -----
+
+func @testupdateop(%a: memref<10xf32>) -> () {
+  %ifCond = constant true
+  acc.update if(%ifCond) host(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: acc.update host(%{{.*}} : memref<10xf32>)
+
+// -----
+
+func @testupdateop(%a: memref<10xf32>) -> () {
+  %ifCond = constant false
+  acc.update if(%ifCond) host(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: func @testupdateop
+// CHECK-NOT: acc.update
+
+// -----
+
+func @testupdateop(%a: memref<10xf32>, %ifCond: i1) -> () {
+  acc.update if(%ifCond) host(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: func @testupdateop(%{{.*}}: memref<10xf32>, [[IFCOND:%.*]]: i1)
+// CHECK: scf.if [[IFCOND]] {
+// CHECK-NEXT: acc.update host(%{{.*}} : memref<10xf32>)
+// CHECK-NEXT: }
+
+// -----
+
+// Operations without an `if` condition must be left unchanged.
+func @testenterdataop(%a: memref<10xf32>) -> () {
+  acc.enter_data create(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: func @testenterdataop(%{{.*}}: memref<10xf32>)
+// CHECK-NEXT: acc.enter_data create(%{{.*}} : memref<10xf32>)
+
+// -----
+
+func @testexitdataop(%a: memref<10xf32>) -> () {
+  acc.exit_data delete(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: func @testexitdataop(%{{.*}}: memref<10xf32>)
+// CHECK-NEXT: acc.exit_data delete(%{{.*}} : memref<10xf32>)
+
+// -----
+
+func @testupdateop(%a: memref<10xf32>) -> () {
+  acc.update host(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: func @testupdateop(%{{.*}}: memref<10xf32>)
+// CHECK-NEXT: acc.update host(%{{.*}} : memref<10xf32>)