diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
@@ -74,6 +74,11 @@
 std::optional<mlir::acc::DataClause>
 getDataClause(mlir::Operation *accDataEntryOp);
 
+/// Used to find out whether a data entry operation is marked implicit.
+/// Returns false if not a data entry operation or if it is a data entry
+/// operation without the implicit flag set.
+bool getImplicitFlag(mlir::Operation *accDataEntryOp);
+
 /// Used to obtain the attribute name for declare.
 static constexpr StringLiteral getDeclareAttrName() {
   return StringLiteral("acc.declare");
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -131,8 +131,13 @@
 // easier to find out whether the variable is in a declare clause and what kind
 // of clause it is.
 def DeclareAttr : OpenACC_Attr<"Declare", "declare"> {
-  let parameters = (ins "DataClauseAttr":$dataClause);
+  let parameters = (ins "DataClauseAttr":$dataClause,
+                        DefaultValuedParameter<"bool", "false">:$implicit);
   let assemblyFormat = "`<` struct(params) `>`";
+  let builders = [AttrBuilder<(ins "DataClauseAttr":$dataClause), [{
+      return $_get($_ctxt, dataClause, /*implicit=*/false);
+    }]>
+  ];
 }
 
 // Attribute to attach functions that perform the pre/post allocation actions or
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -1178,11 +1178,21 @@
     if (!declareAttribute)
       return op.emitError(
           "expect declare attribute on variable in declare operation");
-    if (mlir::cast<mlir::acc::DeclareAttr>(declareAttribute)
-            .getDataClause()
-            .getValue() != dataClauseOptional.value())
+
+    auto declAttr = mlir::cast<mlir::acc::DeclareAttr>(declareAttribute);
+    if (declAttr.getDataClause().getValue() != dataClauseOptional.value())
       return op.emitError(
           "expect matching declare attribute on variable in declare operation");
+
+    // If the variable is marked with implicit attribute, the matching declare
+    // data action must also be marked implicit. The reverse is not checked
+    // since implicit data action may be inserted to do actions like updating
+    // device copy, in which case the variable is not necessarily implicitly
+    // declared.
+    if (declAttr.getImplicit() &&
+        !acc::getImplicitFlag(operand.getDefiningOp()))
+      return op.emitError(
+          "implicitness must match between declare op and flag on variable");
   }
 
   return success();
@@ -1397,3 +1407,11 @@
           .Default([&](mlir::Operation *) { return std::nullopt; })};
   return dataClause;
 }
+
+bool mlir::acc::getImplicitFlag(mlir::Operation *accDataEntryOp) {
+  auto implicit{llvm::TypeSwitch<mlir::Operation *, bool>(accDataEntryOp)
+                    .Case<ACC_DATA_ENTRY_OPS>(
+                        [&](auto entry) { return entry.getImplicit(); })
+                    .Default([&](mlir::Operation *) { return false; })};
+  return implicit;
+}
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -1684,12 +1684,17 @@
   %numGangs = arith.constant 10 : i64
   %numWorkers = arith.constant 10 : i64
 
+  %c20 = arith.constant 20 : i32
+  %alloc = llvm.alloca %c20 x i32 { acc.declare = #acc.declare<dataClause = acc_create, implicit = true> } : (i32) -> !llvm.ptr
+  %createlocal = acc.create varPtr(%alloc : !llvm.ptr) -> !llvm.ptr {implicit = true}
+
   %pa = acc.present varPtr(%a : memref<10x10xf32>) -> memref<10x10xf32>
   %pb = acc.present varPtr(%b : memref<10x10xf32>) -> memref<10x10xf32>
   %pc = acc.present varPtr(%c : memref<10xf32>) -> memref<10xf32>
   %pd = acc.present varPtr(%d : memref<10xf32>) -> memref<10xf32>
-  acc.declare dataOperands(%pa, %pb, %pc, %pd: memref<10x10xf32>, memref<10x10xf32>, memref<10xf32>, memref<10xf32>) {
+  acc.declare dataOperands(%pa, %pb, %pc, %pd, %createlocal: memref<10x10xf32>, memref<10x10xf32>, memref<10xf32>, memref<10xf32>, !llvm.ptr) {
   }
+
   return
 }
 