diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
@@ -34,6 +34,30 @@
 #define GET_OP_CLASSES
 #include "mlir/Dialect/OpenACC/OpenACCOps.h.inc"
 
+// The macros below each expand to a comma-separated list of acc operation
+// classes, so they can stand in for a variadic template argument list such as
+// the one taken by `mlir::isa<...>` or `llvm::TypeSwitch::Case<...>`.
+
+// Operations grouped as data entry operations.
+#define ACC_DATA_ENTRY_OPS \
+  mlir::acc::CopyinOp, mlir::acc::CreateOp, mlir::acc::PresentOp, \
+      mlir::acc::NoCreateOp, mlir::acc::AttachOp, mlir::acc::DevicePtrOp, \
+      mlir::acc::GetDevicePtrOp, mlir::acc::PrivateOp, \
+      mlir::acc::FirstprivateOp, mlir::acc::UpdateDeviceOp, \
+      mlir::acc::UseDeviceOp, mlir::acc::ReductionOp, \
+      mlir::acc::DeclareDeviceResidentOp, mlir::acc::DeclareLinkOp
+// Operations grouped as compute constructs.
+#define ACC_COMPUTE_CONSTRUCT_OPS \
+  mlir::acc::ParallelOp, mlir::acc::KernelsOp, mlir::acc::SerialOp
+// Operations grouped as data constructs.
+#define ACC_DATA_CONSTRUCT_OPS \
+  mlir::acc::DataOp, mlir::acc::EnterDataOp, mlir::acc::ExitDataOp, \
+      mlir::acc::UpdateOp, mlir::acc::HostDataOp, \
+      mlir::acc::DeclareEnterOp, mlir::acc::DeclareExitOp
+// Compute constructs followed by data constructs.
+#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS \
+  ACC_COMPUTE_CONSTRUCT_OPS, ACC_DATA_CONSTRUCT_OPS
+
 namespace mlir {
 namespace acc {
 
@@ -48,6 +72,20 @@
 /// combined and the final mapping value would be 5 (4 | 1).
 enum OpenACCExecMapping { NONE = 0, VECTOR = 1, WORKER = 2, GANG = 4 };
 
+/// Used to obtain the `varPtr` from a data entry operation.
+/// Returns empty value if not a data entry operation.
+mlir::Value getVarPtr(mlir::Operation *accDataEntryOp);
+
+/// Used to obtain the `dataClause` from a data entry operation.
+/// Returns empty optional if not a data entry operation.
+std::optional<mlir::acc::DataClause>
+getDataClause(mlir::Operation *accDataEntryOp);
+
+/// Used to obtain the attribute name for declare.
+static constexpr StringLiteral getDeclareAttrName() {
+  return StringLiteral("acc.declare");
+}
+
 } // namespace acc
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -116,6 +116,25 @@
 def OpenACC_DataClauseAttr : EnumAttr;
 
+class OpenACC_Attr<string name, string attrMnemonic,
+                   list<Trait> traits = [],
+                   string baseCppClass = "::mlir::Attribute">
+    : AttrDef<OpenACC_Dialect, name, traits, baseCppClass> {
+  let mnemonic = attrMnemonic;
+}
+
+// Attribute to describe the declare data clause used on variable.
+// Intended to be used at the variable creation site (on the global op or the
+// corresponding allocation operation). This is used in conjunction with the
+// declare operations (`acc.declare_enter` and `acc.declare_exit`) since those
+// describe how the data action is performed. The attribute itself makes it
+// easier to find out whether the variable is in a declare clause and what kind
+// of clause it is.
+def DeclareAttr : OpenACC_Attr<"Declare", "declare"> {
+  let parameters = (ins "DataClauseAttr":$dataClause);
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
 // Used for data specification in data clauses (2.7.1).
 // Either (or both) extent and upperbound must be specified.
 def OpenACC_DataBoundsOp : OpenACC_Op<"bounds",
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -987,7 +987,7 @@
         op->getLoc(),
         "at least one operand must appear on the declare operation");
 
-  for (mlir::Value operand : operands)
+  for (mlir::Value operand : operands) {
     if (!mlir::isa(
@@ -995,6 +995,32 @@
       return op.emitError(
           "expect valid declare data entry operation or acc.getdeviceptr "
           "as defining op");
+
+    mlir::Value varPtr{getVarPtr(operand.getDefiningOp())};
+    assert(varPtr && "declare operands can only be data entry operations which "
+                     "must have varPtr");
+    std::optional<mlir::acc::DataClause> dataClauseOptional{
+        getDataClause(operand.getDefiningOp())};
+    assert(dataClauseOptional.has_value() &&
+           "declare operands can only be data entry operations which must have "
+           "dataClause");
+
+    // If varPtr has no defining op - there is nothing to check further.
+    if (!varPtr.getDefiningOp())
+      continue;
+
+    // Check that the varPtr has a declare attribute. An attribute of an
+    // unexpected kind is diagnosed like a missing one instead of asserting.
+    auto declareAttribute{llvm::dyn_cast_or_null<mlir::acc::DataClauseAttr>(
+        varPtr.getDefiningOp()->getAttr(mlir::acc::getDeclareAttrName()))};
+    if (!declareAttribute)
+      return op.emitError(
+          "expect declare attribute on variable in declare operation");
+    if (declareAttribute.getValue() != dataClauseOptional.value())
+      return op.emitError(
+          "expect matching declare attribute on variable in declare operation");
+  }
+
   return success();
 }
 
@@ -1106,3 +1132,26 @@
 
 #define GET_TYPEDEF_CLASSES
 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// acc dialect utilities
+//===----------------------------------------------------------------------===//
+
+mlir::Value mlir::acc::getVarPtr(mlir::Operation *accDataEntryOp) {
+  auto varPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataEntryOp)
+                  .Case<ACC_DATA_ENTRY_OPS>(
+                      [&](auto entry) { return entry.getVarPtr(); })
+                  .Default([&](mlir::Operation *) { return mlir::Value(); })};
+  return varPtr;
+}
+
+std::optional<mlir::acc::DataClause>
+mlir::acc::getDataClause(mlir::Operation *accDataEntryOp) {
+  auto dataClause{
+      llvm::TypeSwitch<mlir::Operation *, std::optional<mlir::acc::DataClause>>(
+          accDataEntryOp)
+          .Case<ACC_DATA_ENTRY_OPS>(
+              [&](auto entry) { return entry.getDataClause(); })
+          .Default([&](mlir::Operation *) { return std::nullopt; })};
+  return dataClause;
+}
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -1605,20 +1605,20 @@
 
 // -----
 
-llvm.mlir.global external @globalvar() : i32 {
+llvm.mlir.global external @globalvar() { acc.declare = #acc<data_clause acc_create> } : i32 {
   %0 = llvm.mlir.constant(0 : i32) : i32
   llvm.return %0 : i32
 }
 
 acc.global_ctor @acc_constructor {
-  %0 = llvm.mlir.addressof @globalvar : !llvm.ptr
+  %0 = llvm.mlir.addressof @globalvar { acc.declare = #acc<data_clause acc_create> } : !llvm.ptr
   %1 = acc.create varPtr(%0 : !llvm.ptr) -> !llvm.ptr
   acc.declare_enter dataOperands(%1 : !llvm.ptr)
   acc.terminator
 }
 
 acc.global_dtor @acc_destructor {
-  %0 = llvm.mlir.addressof @globalvar : !llvm.ptr
+  %0 = llvm.mlir.addressof @globalvar { acc.declare = #acc<data_clause acc_create> } : !llvm.ptr
   %1 = acc.getdeviceptr varPtr(%0 : !llvm.ptr) -> !llvm.ptr { dataClause = #acc<data_clause acc_create>}
   acc.declare_exit dataOperands(%1 : !llvm.ptr)
   acc.delete accPtr(%1 : !llvm.ptr)
@@ -1626,11 +1626,11 @@
 }
 
 // CHECK-LABEL: acc.global_ctor @acc_constructor
-// CHECK: %[[ADDR:.*]] = llvm.mlir.addressof @globalvar : !llvm.ptr
+// CHECK: %[[ADDR:.*]] = llvm.mlir.addressof @globalvar {acc.declare = #acc<data_clause acc_create>} : !llvm.ptr
 // CHECK-NEXT: %[[CREATE:.*]] = acc.create varPtr(%[[ADDR]] : !llvm.ptr) -> !llvm.ptr
 // CHECK-NEXT: acc.declare_enter dataOperands(%[[CREATE]] : !llvm.ptr)
 // CHECK: acc.global_dtor @acc_destructor
-// CHECK: %[[ADDR:.*]] = llvm.mlir.addressof @globalvar : !llvm.ptr
+// CHECK: %[[ADDR:.*]] = llvm.mlir.addressof @globalvar {acc.declare = #acc<data_clause acc_create>} : !llvm.ptr
 // CHECK-NEXT: %[[DELETE:.*]] = acc.getdeviceptr varPtr(%[[ADDR]] : !llvm.ptr) -> !llvm.ptr {dataClause = #acc<data_clause acc_create>}
 // CHECK-NEXT: acc.declare_exit dataOperands(%[[DELETE]] : !llvm.ptr)
 // CHECK-NEXT: acc.delete accPtr(%[[DELETE]] : !llvm.ptr)