diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp @@ -36,6 +36,11 @@ static constexpr uint64_t createFlag = 0; /// 1 = to/copyin static constexpr uint64_t copyinFlag = 1; +/// 2 = from/copyout +static constexpr uint64_t copyoutFlag = 2; +/// 8 = delete +static constexpr uint64_t deleteFlag = 8; + /// Default value for the device id static constexpr int64_t defaultDevice = -1; @@ -57,10 +62,10 @@ } /// Create the location struct from the operation location information. -static llvm::Value *createSourceLocationInfo(acc::EnterDataOp &op, - OpenACCIRBuilder &builder) { - auto loc = op.getLoc(); - auto funcOp = op.getOperation()->getParentOfType(); +static llvm::Value *createSourceLocationInfo(OpenACCIRBuilder &builder, + Operation *op) { + auto loc = op->getLoc(); + auto funcOp = op->getParentOfType(); StringRef funcName = funcOp ? funcOp.getName() : "unknown"; llvm::Constant *locStr = createSourceLocStrFromLocation(loc, builder, funcName); @@ -81,10 +86,16 @@ /// Return the runtime function used to lower the given operation. static llvm::Function *getAssociatedFunction(OpenACCIRBuilder &builder, - Operation &op) { - if (isa(op)) - return builder.getOrCreateRuntimeFunctionPtr( - llvm::omp::OMPRTL___tgt_target_data_begin_mapper); + Operation *op) { + return llvm::TypeSwitch(op) + .Case([&](acc::EnterDataOp enterDataOp) { + return builder.getOrCreateRuntimeFunctionPtr( + llvm::omp::OMPRTL___tgt_target_data_begin_mapper); + }) + .Case([&](acc::ExitDataOp exitDataOp) { + return builder.getOrCreateRuntimeFunctionPtr( + llvm::omp::OMPRTL___tgt_target_data_end_mapper); + }); llvm_unreachable("Unknown OpenACC operation"); } @@ -105,7 +116,7 @@ /// to populate the future functions arguments. static LogicalResult processOperands(llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation, Operation &op, + LLVM::ModuleTranslation &moduleTranslation, Operation *op, ValueRange operands, unsigned totalNbOperand, uint64_t operandFlag, SmallVector &flags, SmallVector &names, unsigned &index, @@ -137,7 +148,7 @@ dataPtr = dataValue; dataSize = getSizeInBytes(builder, dataValue); } else { - return op.emitOpError() + return op->emitOpError() << "Data operand must be legalized before translation." << "Unsupported type: " << data.getType(); } @@ -171,28 +182,81 @@ return success(); } +/// Process data operands from acc::EnterDataOp +static LogicalResult +processDataOperands(llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation, + acc::EnterDataOp op, SmallVector &flags, + SmallVector &names, unsigned &index, + llvm::AllocaInst *argsBase, llvm::AllocaInst *args, + llvm::AllocaInst *argSizes) { + // TODO add `create_zero` and `attach` operands + + // Create operands are handled as `alloc` call. + if (failed(processOperands(builder, moduleTranslation, op, + op.createOperands(), op.getNumDataOperands(), + createFlag, flags, names, index, argsBase, args, + argSizes))) + return failure(); + + // Copyin operands are handled as `to` call. + if (failed(processOperands(builder, moduleTranslation, op, + op.copyinOperands(), op.getNumDataOperands(), + copyinFlag, flags, names, index, argsBase, args, + argSizes))) + return failure(); + + return success(); +} + +/// Process data operands from acc::ExitDataOp +static LogicalResult +processDataOperands(llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation, + acc::ExitDataOp op, SmallVector &flags, + SmallVector &names, unsigned &index, + llvm::AllocaInst *argsBase, llvm::AllocaInst *args, + llvm::AllocaInst *argSizes) { + // TODO add `detach` operands + + // Delete operands are handled as `delete` call. + if (failed(processOperands(builder, moduleTranslation, op, + op.deleteOperands(), op.getNumDataOperands(), + deleteFlag, flags, names, index, argsBase, args, + argSizes))) + return failure(); + + // Copyout operands are handled as `from` call. + if (failed(processOperands(builder, moduleTranslation, op, + op.copyoutOperands(), op.getNumDataOperands(), + copyoutFlag, flags, names, index, argsBase, args, + argSizes))) + return failure(); + + return success(); +} + //===----------------------------------------------------------------------===// // Conversion functions //===----------------------------------------------------------------------===// -/// Converts an OpenACC enter_data operartion into LLVM IR. +/// Converts an OpenACC standalone data operation into LLVM IR. +template static LogicalResult -convertEnterDataOp(Operation &op, llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation) { - auto enterDataOp = cast(op); - auto enclosingFuncOp = op.getParentOfType(); +convertStandaloneDataOp(OpTy &op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + auto enclosingFuncOp = + op.getOperation()->template getParentOfType(); llvm::Function *enclosingFunction = moduleTranslation.lookupFunction(enclosingFuncOp.getName()); OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder(); - auto *srcLocInfo = createSourceLocationInfo(enterDataOp, *accBuilder); + auto *srcLocInfo = createSourceLocationInfo(*accBuilder, op); auto *mapperFunc = getAssociatedFunction(*accBuilder, op); // Number of arguments in the enter_data operation. - // TODO include create_zero and attach operands. - unsigned totalNbOperand = - enterDataOp.createOperands().size() + enterDataOp.copyinOperands().size(); + unsigned totalNbOperand = op.getNumDataOperands(); // TODO could be moved to OpenXXIRBuilder? llvm::LLVMContext &ctx = builder.getContext(); @@ -214,18 +278,8 @@ SmallVector names; unsigned index = 0; - // Create operands are handled as `alloc` call. - if (failed(processOperands(builder, moduleTranslation, op, - enterDataOp.createOperands(), totalNbOperand, - createFlag, flags, names, index, argsBase, args, - argSizes))) - return failure(); - - // Copyin operands are handled as `to` call. - if (failed(processOperands(builder, moduleTranslation, op, - enterDataOp.copyinOperands(), totalNbOperand, - copyinFlag, flags, names, index, argsBase, args, - argSizes))) + if (failed(processDataOperands(builder, moduleTranslation, op, flags, names, + index, argsBase, args, argSizes))) return failure(); llvm::GlobalVariable *maptypes = @@ -282,8 +336,13 @@ LLVM::ModuleTranslation &moduleTranslation) const { return llvm::TypeSwitch(op) - .Case([&](acc::EnterDataOp) { - return convertEnterDataOp(*op, builder, moduleTranslation); + .Case([&](acc::EnterDataOp enterDataOp) { + return convertStandaloneDataOp(enterDataOp, builder, + moduleTranslation); + }) + .Case([&](acc::ExitDataOp exitDataOp) { + return convertStandaloneDataOp(exitDataOp, builder, + moduleTranslation); }) .Default([&](Operation *op) { return op->emitError("unsupported OpenACC operation: ") diff --git a/mlir/test/Target/LLVMIR/openacc-llvm.mlir b/mlir/test/Target/LLVMIR/openacc-llvm.mlir --- a/mlir/test/Target/LLVMIR/openacc-llvm.mlir +++ b/mlir/test/Target/LLVMIR/openacc-llvm.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s +// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s llvm.func @testenterdataop(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: !llvm.ptr) { %0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> @@ -63,3 +63,63 @@ // CHECK: call void @__tgt_target_data_begin_mapper(%struct.ident_t* [[LOCGLOBAL]], i64 -1, i32 2, i8** [[ARGBASE_ALLOCA_GEP]], i8** [[ARG_ALLOCA_GEP]], i64* [[SIZE_ALLOCA_GEP]], i64* getelementptr inbounds ([{{[0-9]*}} x i64], [{{[0-9]*}} x i64]* [[MAPTYPES]], i32 0, i32 0), i8** getelementptr inbounds ([{{[0-9]*}} x i8*], [{{[0-9]*}} x i8*]* [[MAPNAMES]], i32 0, i32 0), i8** null) // CHECK: declare void @__tgt_target_data_begin_mapper(%struct.ident_t*, i64, i32, i8**, i8**, i64*, i64*, i8**, i8**) #0 + +// ----- + + +llvm.func @testexitdataop(%arg0: !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, %arg1: !llvm.ptr) { + %0 = llvm.mlir.constant(10 : index) : i64 + %1 = llvm.mlir.null : !llvm.ptr + %2 = llvm.getelementptr %1[%0] : (!llvm.ptr, i64) -> !llvm.ptr + %3 = llvm.ptrtoint %2 : !llvm.ptr to i64 + %4 = llvm.extractvalue %arg0[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %5 = llvm.mlir.undef : !llvm.struct<"openacc_data", (struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, ptr, i64)> + %6 = llvm.insertvalue %arg0, %5[0] : !llvm.struct<"openacc_data", (struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, ptr, i64)> + %7 = llvm.insertvalue %4, %6[1] : !llvm.struct<"openacc_data", (struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, ptr, i64)> + %8 = llvm.insertvalue %3, %7[2] : !llvm.struct<"openacc_data", (struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, ptr, i64)> + acc.exit_data copyout(%arg1 : !llvm.ptr) delete(%8 : !llvm.struct<"openacc_data", (struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, ptr, i64)>) + llvm.return +} + +// CHECK: %struct.ident_t = type { i32, i32, i32, i32, i8* } + +// CHECK: [[LOCSTR:@.*]] = private unnamed_addr constant [{{[0-9]*}} x i8] c";{{.*}};testexitdataop;{{[0-9]*}};{{[0-9]*}};;\00", align 1 +// CHECK: [[LOCGLOBAL:@.*]] = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 0, i8* getelementptr inbounds ([{{[0-9]*}} x i8], [{{[0-9]*}} x i8]* [[LOCSTR]], i32 0, i32 0) }, align 8 +// CHECK: [[MAPNAME1:@.*]] = private unnamed_addr constant [{{[0-9]*}} x i8] c";{{.*}};unknown;{{[0-9]*}};{{[0-9]*}};;\00", align 1 +// CHECK: [[MAPNAME2:@.*]] = private unnamed_addr constant [{{[0-9]*}} x i8] c";{{.*}};unknown;{{[0-9]*}};{{[0-9]*}};;\00", align 1 +// CHECK: [[MAPTYPES:@.*]] = private unnamed_addr constant [{{[0-9]*}} x i64] [i64 8, i64 2] +// CHECK: [[MAPNAMES:@.*]] = private constant [{{[0-9]*}} x i8*] [i8* getelementptr inbounds ([{{[0-9]*}} x i8], [{{[0-9]*}} x i8]* [[MAPNAME1]], i32 0, i32 0), i8* getelementptr inbounds ([{{[0-9]*}} x i8], [{{[0-9]*}} x i8]* [[MAPNAME2]], i32 0, i32 0)] + +// CHECK: define void @testexitdataop({ float*, float*, i64, [1 x i64], [1 x i64] } %{{.*}}, float* [[SIMPLEPTR:%.*]]) +// CHECK: [[ARGBASE_ALLOCA:%.*]] = alloca [{{[0-9]*}} x i8*], align 8 +// CHECK: [[ARG_ALLOCA:%.*]] = alloca [{{[0-9]*}} x i8*], align 8 +// CHECK: [[SIZE_ALLOCA:%.*]] = alloca [{{[0-9]*}} x i64], align 8 + +// CHECK: [[ARGBASE:%.*]] = extractvalue %openacc_data %{{.*}}, 0 +// CHECK: [[ARG:%.*]] = extractvalue %openacc_data %{{.*}}, 1 +// CHECK: [[ARGSIZE:%.*]] = extractvalue %openacc_data %{{.*}}, 2 +// CHECK: [[ARGBASEGEP:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[ARGBASE_ALLOCA]], i32 0, i32 0 +// CHECK: [[ARGBASEGEPCAST:%.*]] = bitcast i8** [[ARGBASEGEP]] to { float*, float*, i64, [1 x i64], [1 x i64] }* +// CHECK: store { float*, float*, i64, [1 x i64], [1 x i64] } [[ARGBASE]], { float*, float*, i64, [1 x i64], [1 x i64] }* [[ARGBASEGEPCAST]], align 8 +// CHECK: [[ARGGEP:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[ARG_ALLOCA]], i32 0, i32 0 +// CHECK: [[ARGGEPCAST:%.*]] = bitcast i8** [[ARGGEP]] to float** +// CHECK: store float* [[ARG]], float** [[ARGGEPCAST]], align 8 +// CHECK: [[SIZEGEP:%.*]] = getelementptr inbounds [2 x i64], [2 x i64]* [[SIZE_ALLOCA]], i32 0, i32 0 +// CHECK: store i64 [[ARGSIZE]], i64* [[SIZEGEP]], align 4 + +// CHECK: [[ARGBASEGEP:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[ARGBASE_ALLOCA]], i32 0, i32 1 +// CHECK: [[ARGBASEGEPCAST:%.*]] = bitcast i8** [[ARGBASEGEP]] to float** +// CHECK: store float* [[SIMPLEPTR]], float** [[ARGBASEGEPCAST]], align 8 +// CHECK: [[ARGGEP:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[ARG_ALLOCA]], i32 0, i32 1 +// CHECK: [[ARGGEPCAST:%.*]] = bitcast i8** [[ARGGEP]] to float** +// CHECK: store float* [[SIMPLEPTR]], float** [[ARGGEPCAST]], align 8 +// CHECK: [[SIZEGEP:%.*]] = getelementptr inbounds [2 x i64], [2 x i64]* [[SIZE_ALLOCA]], i32 0, i32 1 +// CHECK: store i64 ptrtoint (i1** getelementptr (i1*, i1** null, i32 1) to i64), i64* [[SIZEGEP]], align 4 + +// CHECK: [[ARGBASE_ALLOCA_GEP:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[ARGBASE_ALLOCA]], i32 0, i32 0 +// CHECK: [[ARG_ALLOCA_GEP:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[ARG_ALLOCA]], i32 0, i32 0 +// CHECK: [[SIZE_ALLOCA_GEP:%.*]] = getelementptr inbounds [2 x i64], [2 x i64]* [[SIZE_ALLOCA]], i32 0, i32 0 + +// CHECK: call void @__tgt_target_data_end_mapper(%struct.ident_t* [[LOCGLOBAL]], i64 -1, i32 2, i8** [[ARGBASE_ALLOCA_GEP]], i8** [[ARG_ALLOCA_GEP]], i64* [[SIZE_ALLOCA_GEP]], i64* getelementptr inbounds ([{{[0-9]*}} x i64], [{{[0-9]*}} x i64]* [[MAPTYPES]], i32 0, i32 0), i8** getelementptr inbounds ([{{[0-9]*}} x i8*], [{{[0-9]*}} x i8*]* [[MAPNAMES]], i32 0, i32 0), i8** null) + +// CHECK: declare void @__tgt_target_data_end_mapper(%struct.ident_t*, i64, i32, i8**, i8**, i64*, i64*, i8**, i8**) #0