diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -43,6 +43,7 @@
 #define GEN_PASS_DECL_SIMPLIFYREGIONLITE
 #define GEN_PASS_DECL_ALGEBRAICSIMPLIFICATION
 #define GEN_PASS_DECL_POLYMORPHICOPCONVERSION
+#define GEN_PASS_DECL_OPENACCDATAOPERANDCONVERSION
 #include "flang/Optimizer/Transforms/Passes.h.inc"
 
 std::unique_ptr<mlir::Pass> createAbstractResultOnFuncOptPass();
@@ -70,6 +71,9 @@
 std::unique_ptr<mlir::Pass>
 createAlgebraicSimplificationPass(const mlir::GreedyRewriteConfig &config);
 std::unique_ptr<mlir::Pass> createPolymorphicOpConversionPass();
+/// Creates the pass that legalizes the FIR data operands of OpenACC
+/// operations for translation to LLVM IR.
+std::unique_ptr<mlir::Pass> createOpenACCDataOperandConversionPass();
 
 // declarative passes
 #define GEN_PASS_REGISTRATION
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -284,5 +284,24 @@
   ];
 }
 
-
+def OpenACCDataOperandConversion : Pass<"fir-openacc-data-operand-conversion", "::mlir::func::FuncOp"> {
+  let summary = "Convert the FIR operands in OpenACC ops to LLVM dialect";
+  let description = [{
+    Legalizes the data operands of the OpenACC operations `acc.data`,
+    `acc.enter_data`, `acc.exit_data`, `acc.parallel` and `acc.update` for
+    translation to LLVM IR: every `!fir.ref` data operand is wrapped in a
+    `builtin.unrealized_conversion_cast` to the type produced by the
+    FIR-to-LLVM type converter. References to boxed entities are not
+    supported.
+  }];
+  let dependentDialects = ["mlir::LLVM::LLVMDialect"];
+  // NOTE: `useOpaquePointers` is not yet forwarded to the FIR type converter
+  // used by this pass.
+  let options = [
+    Option<"useOpaquePointers", "use-opaque-pointers", "bool",
+           /*default=*/"true", "Generate LLVM IR using opaque pointers "
+           "instead of typed pointers">,
+  ];
+}
+
 #endif // FLANG_OPTIMIZER_TRANSFORMS_PASSES
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -15,6 +15,7 @@
   SimplifyIntrinsics.cpp
   AddDebugFoundation.cpp
   PolymorphicOpConversion.cpp
+  OpenACC/OpenACCDataOperandConversion.cpp
 
   DEPENDS
   FIRDialect
@@ -22,6 +23,7 @@
 
   LINK_LIBS
   FIRBuilder
+  FIRCodeGen
   FIRDialect
   FIRDialectSupport
   FIRSupport
diff --git 
a/flang/lib/Optimizer/Transforms/OpenACC/OpenACCDataOperandConversion.cpp b/flang/lib/Optimizer/Transforms/OpenACC/OpenACCDataOperandConversion.cpp
new file mode 100644
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/OpenACC/OpenACCDataOperandConversion.cpp
@@ -0,0 +1,215 @@
+//===- OpenACCDataOperandConversion.cpp - OpenACC data operand conversion -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass legalizes the data operands of OpenACC operations for translation
+// to LLVM IR: each `!fir.ref` data operand is wrapped in a
+// `builtin.unrealized_conversion_cast` to the type produced by the FIR-to-LLVM
+// type converter. Operations from the FIR dialect are left unchanged.
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "llvm/ADT/STLExtras.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+#include <type_traits>
+
+namespace fir {
+#define GEN_PASS_DEF_OPENACCDATAOPERANDCONVERSION
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+#define DEBUG_TYPE "flang-openacc-conversion"
+#include "../CodeGen/TypeConverter.h"
+
+using namespace fir;
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Conversion patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Rewrites an OpenACC operation so that each of its data operands (expected
+/// to be `!fir.ref` values) is replaced by an `unrealized_conversion_cast` of
+/// that value to its LLVM-converted type. Non-data operands and all attributes
+/// are forwarded unchanged. For the operations that carry a region
+/// (`acc.data`, `acc.parallel`) the original region is moved into the
+/// rewritten operation.
+template <typename Op>
+class LegalizeDataOpForLLVMTranslation : public ConvertOpToLLVMPattern<Op> {
+  using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+                  ConversionPatternRewriter &builder) const override {
+    Location loc = op.getLoc();
+    fir::LLVMTypeConverter &converter =
+        *static_cast<fir::LLVMTypeConverter *>(this->getTypeConverter());
+
+    unsigned numDataOperands = op.getNumDataOperands();
+
+    // Keep the non data operands without modification; the data operands are
+    // the trailing `numDataOperands` operands of the op.
+    auto nonDataOperands = adaptor.getOperands().take_front(
+        adaptor.getOperands().size() - numDataOperands);
+    SmallVector<Value> convertedOperands;
+    convertedOperands.append(nonDataOperands.begin(), nonDataOperands.end());
+
+    // Go over the data operands and legalize them for translation.
+    for (unsigned idx = 0; idx < numDataOperands; ++idx) {
+      Value originalDataOperand = op.getDataOperand(idx);
+      auto refTy =
+          originalDataOperand.getType().dyn_cast<fir::ReferenceType>();
+      if (!refTy)
+        return builder.notifyMatchFailure(op, "expecting a reference type");
+      if (refTy.getEleTy().isa<fir::BaseBoxType>())
+        return builder.notifyMatchFailure(op, "BaseBoxType not supported");
+      // A failed type conversion yields a null type: report it as a match
+      // failure instead of asserting in an unchecked cast.
+      mlir::Type convertedType = converter.convertType(refTy);
+      if (!convertedType || !convertedType.isa<mlir::LLVM::LLVMPointerType>())
+        return builder.notifyMatchFailure(
+            op, "expecting the reference to convert to an LLVM pointer");
+      mlir::Value castedOperand =
+          builder
+              .create<mlir::UnrealizedConversionCastOp>(loc, convertedType,
+                                                        originalDataOperand)
+              .getResult(0);
+      convertedOperands.push_back(castedOperand);
+    }
+
+    if constexpr (std::is_same_v<Op, acc::DataOp> ||
+                  std::is_same_v<Op, acc::ParallelOp>) {
+      // The generated builder only creates an empty region, so the original
+      // body must be moved into the new op explicitly; replacing the op
+      // directly would drop it. The body is not type-converted since the FIR
+      // dialect stays legal in this conversion.
+      auto newOp = builder.create<Op>(loc, TypeRange(), convertedOperands,
+                                      op.getOperation()->getAttrs());
+      builder.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
+                                 newOp.getRegion().end());
+      builder.eraseOp(op);
+    } else {
+      builder.replaceOpWithNewOp<Op>(op, TypeRange(), convertedOperands,
+                                     op.getOperation()->getAttrs());
+    }
+
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+/// Function pass that legalizes the FIR data operands of OpenACC operations
+/// for translation to LLVM IR.
+struct OpenACCDataOperandConversion
+    : public fir::impl::OpenACCDataOperandConversionBase<
+          OpenACCDataOperandConversion> {
+  using Base::Base;
+
+  void runOnOperation() override;
+};
+} // namespace
+
+/// Returns true if every value in \p operands already has an LLVM pointer
+/// type, i.e. has been legalized by LegalizeDataOpForLLVMTranslation.
+static bool allDataOperandsAreConverted(ValueRange operands) {
+  return llvm::all_of(operands, [](Value operand) {
+    return operand.getType().isa<LLVM::LLVMPointerType>();
+  });
+}
+
+void OpenACCDataOperandConversion::runOnOperation() {
+  auto op = getOperation();
+  auto *context = op.getContext();
+
+  // Convert to OpenACC operations with LLVM IR dialect
+  RewritePatternSet patterns(context);
+  // TODO: the `use-opaque-pointers` option is not honored yet; the FIR type
+  // converter below is not configured from `useOpaquePointers`.
+  fir::LLVMTypeConverter converter(
+      op.getOperation()->getParentOfType<mlir::ModuleOp>(), true);
+  patterns.add<LegalizeDataOpForLLVMTranslation<acc::DataOp>>(converter);
+  patterns.add<LegalizeDataOpForLLVMTranslation<acc::EnterDataOp>>(converter);
+  patterns.add<LegalizeDataOpForLLVMTranslation<acc::ExitDataOp>>(converter);
+  patterns.add<LegalizeDataOpForLLVMTranslation<acc::ParallelOp>>(converter);
+  patterns.add<LegalizeDataOpForLLVMTranslation<acc::UpdateOp>>(converter);
+
+  // FIR and LLVM operations (including the casts created by the patterns)
+  // are legal; the OpenACC operations below are legal once all of their data
+  // operands have LLVM pointer types.
+  ConversionTarget target(*context);
+  target.addLegalDialect<fir::FIROpsDialect>();
+  target.addLegalDialect<LLVM::LLVMDialect>();
+  target.addLegalOp<mlir::UnrealizedConversionCastOp>();
+
+  target.addDynamicallyLegalOp<acc::DataOp>([](acc::DataOp op) {
+    return allDataOperandsAreConverted(op.getCopyOperands()) &&
+           allDataOperandsAreConverted(op.getCopyinOperands()) &&
+           allDataOperandsAreConverted(op.getCopyinReadonlyOperands()) &&
+           allDataOperandsAreConverted(op.getCopyoutOperands()) &&
+           allDataOperandsAreConverted(op.getCopyoutZeroOperands()) &&
+           allDataOperandsAreConverted(op.getCreateOperands()) &&
+           allDataOperandsAreConverted(op.getCreateZeroOperands()) &&
+           allDataOperandsAreConverted(op.getNoCreateOperands()) &&
+           allDataOperandsAreConverted(op.getPresentOperands()) &&
+           allDataOperandsAreConverted(op.getDeviceptrOperands()) &&
+           allDataOperandsAreConverted(op.getAttachOperands());
+  });
+
+  target.addDynamicallyLegalOp<acc::EnterDataOp>([](acc::EnterDataOp op) {
+    return allDataOperandsAreConverted(op.getCopyinOperands()) &&
+           allDataOperandsAreConverted(op.getCreateOperands()) &&
+           allDataOperandsAreConverted(op.getCreateZeroOperands()) &&
+           allDataOperandsAreConverted(op.getAttachOperands());
+  });
+
+  target.addDynamicallyLegalOp<acc::ExitDataOp>([](acc::ExitDataOp op) {
+    return allDataOperandsAreConverted(op.getCopyoutOperands()) &&
+           allDataOperandsAreConverted(op.getDeleteOperands()) &&
+           allDataOperandsAreConverted(op.getDetachOperands());
+  });
+
+  target.addDynamicallyLegalOp<acc::ParallelOp>([](acc::ParallelOp op) {
+    return allDataOperandsAreConverted(op.getReductionOperands()) &&
+           allDataOperandsAreConverted(op.getCopyOperands()) &&
+           allDataOperandsAreConverted(op.getCopyinOperands()) &&
+           allDataOperandsAreConverted(op.getCopyinReadonlyOperands()) &&
+           allDataOperandsAreConverted(op.getCopyoutOperands()) &&
+           allDataOperandsAreConverted(op.getCopyoutZeroOperands()) &&
+           allDataOperandsAreConverted(op.getCreateOperands()) &&
+           allDataOperandsAreConverted(op.getCreateZeroOperands()) &&
+           allDataOperandsAreConverted(op.getNoCreateOperands()) &&
+           allDataOperandsAreConverted(op.getPresentOperands()) &&
+           allDataOperandsAreConverted(op.getDevicePtrOperands()) &&
+           allDataOperandsAreConverted(op.getAttachOperands()) &&
+           allDataOperandsAreConverted(op.getGangPrivateOperands()) &&
+           allDataOperandsAreConverted(op.getGangFirstPrivateOperands());
+  });
+
+  target.addDynamicallyLegalOp<acc::UpdateOp>([](acc::UpdateOp op) {
+    return allDataOperandsAreConverted(op.getHostOperands()) &&
+           allDataOperandsAreConverted(op.getDeviceOperands());
+  });
+
+  if (failed(applyPartialConversion(op, target, std::move(patterns))))
+    signalPassFailure();
+}
+
+std::unique_ptr<mlir::Pass> fir::createOpenACCDataOperandConversionPass() {
+  return std::make_unique<OpenACCDataOperandConversion>();
+}
diff --git a/flang/test/Transforms/OpenACC/convert-data-operands-to-llvmir.fir b/flang/test/Transforms/OpenACC/convert-data-operands-to-llvmir.fir
new file mode 100644
--- /dev/null
+++ b/flang/test/Transforms/OpenACC/convert-data-operands-to-llvmir.fir
@@ -0,0 +1,85 @@
+// RUN: fir-opt -fir-openacc-data-operand-conversion='use-opaque-pointers=0' -split-input-file %s | FileCheck %s
+
+func.func @_QQsub1() attributes {fir.bindc_name = "arr"} {
+  %0 = fir.address_of(@_QFEa) : !fir.ref<!fir.array<10xf32>>
+  acc.data copy(%0 : !fir.ref<!fir.array<10xf32>>) {
+    acc.terminator
+  }
+  return
+}
+
+// CHECK-LABEL: func.func @_QQsub1() attributes {fir.bindc_name = "arr"} {
+// CHECK: %[[ADDR:.*]] = fir.address_of(@_QFEa) : !fir.ref<!fir.array<10xf32>>
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[ADDR]] : !fir.ref<!fir.array<10xf32>> to !llvm.ptr<array<10 x f32>>
+// CHECK: acc.data copy(%[[CAST]] : !llvm.ptr<array<10 x f32>>)
+// CHECK: acc.terminator
+
+// -----
+
+func.func @_QQsub_enter_exit() attributes {fir.bindc_name = "a"} {
+  %0 = fir.address_of(@_QFEa) : !fir.ref<!fir.array<10xf32>>
+  acc.enter_data copyin(%0 : !fir.ref<!fir.array<10xf32>>)
+  acc.exit_data copyout(%0 : !fir.ref<!fir.array<10xf32>>)
+  return
+}
+
+// CHECK-LABEL: func.func @_QQsub_enter_exit() attributes {fir.bindc_name = "a"} {
+// CHECK: %[[ADDR:.*]] = fir.address_of(@_QFEa) : !fir.ref<!fir.array<10xf32>>
+// 
CHECK: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[ADDR]] : !fir.ref<!fir.array<10xf32>> to !llvm.ptr<array<10 x f32>>
+// CHECK: acc.enter_data copyin(%[[CAST0]] : !llvm.ptr<array<10 x f32>>)
+// CHECK: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[ADDR]] : !fir.ref<!fir.array<10xf32>> to !llvm.ptr<array<10 x f32>>
+// CHECK: acc.exit_data copyout(%[[CAST1]] : !llvm.ptr<array<10 x f32>>)
+
+// -----
+
+func.func @_QQsub_update() attributes {fir.bindc_name = "a"} {
+  %0 = fir.address_of(@_QFEa) : !fir.ref<!fir.array<10xf32>>
+  acc.update device(%0 : !fir.ref<!fir.array<10xf32>>)
+  return
+}
+
+// CHECK-LABEL: func.func @_QQsub_update() attributes {fir.bindc_name = "a"} {
+// CHECK: %[[ADDR:.*]] = fir.address_of(@_QFEa) : !fir.ref<!fir.array<10xf32>>
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[ADDR]] : !fir.ref<!fir.array<10xf32>> to !llvm.ptr<array<10 x f32>>
+// CHECK: acc.update device(%[[CAST]] : !llvm.ptr<array<10 x f32>>)
+
+// -----
+
+func.func @_QQsub_parallel() attributes {fir.bindc_name = "test"} {
+  %0 = fir.address_of(@_QFEa) : !fir.ref<!fir.array<10xf32>>
+  %1 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFEi"}
+  acc.parallel copyin(%0: !fir.ref<!fir.array<10xf32>>) {
+    acc.loop {
+      %c1_i32 = arith.constant 1 : i32
+      %2 = fir.convert %c1_i32 : (i32) -> index
+      %c10_i32 = arith.constant 10 : i32
+      %3 = fir.convert %c10_i32 : (i32) -> index
+      %c1 = arith.constant 1 : index
+      %4 = fir.convert %2 : (index) -> i32
+      %5:2 = fir.do_loop %arg0 = %2 to %3 step %c1 iter_args(%arg1 = %4) -> (index, i32) {
+        fir.store %arg1 to %1 : !fir.ref<i32>
+        %6 = fir.load %1 : !fir.ref<i32>
+        %7 = fir.convert %6 : (i32) -> f32
+        %c10_i64 = arith.constant 10 : i64
+        %c1_i64 = arith.constant 1 : i64
+        %8 = arith.subi %c10_i64, %c1_i64 : i64
+        %9 = fir.coordinate_of %0, %8 : (!fir.ref<!fir.array<10xf32>>, i64) -> !fir.ref<f32>
+        fir.store %7 to %9 : !fir.ref<f32>
+        %10 = arith.addi %arg0, %c1 : index
+        %11 = fir.convert %c1 : (index) -> i32
+        %12 = fir.load %1 : !fir.ref<i32>
+        %13 = arith.addi %12, %11 : i32
+        fir.result %10, %13 : index, i32
+      }
+      fir.store %5#1 to %1 : !fir.ref<i32>
+      acc.yield
+    }
+    acc.yield
+  }
+  return
+}
+
+// CHECK-LABEL: func.func @_QQsub_parallel() attributes {fir.bindc_name = "test"} {
+// CHECK: %[[ADDR:.*]] = fir.address_of(@_QFEa) : !fir.ref<!fir.array<10xf32>>
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[ADDR]] : !fir.ref<!fir.array<10xf32>> to !llvm.ptr<array<10 x f32>>
+// CHECK: acc.parallel copyin(%[[CAST]]: !llvm.ptr<array<10 x f32>>) {