diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt --- a/mlir/include/mlir/Dialect/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/CMakeLists.txt @@ -3,6 +3,7 @@ add_subdirectory(GPU) add_subdirectory(Linalg) add_subdirectory(LLVMIR) +add_subdirectory(OpenACC) add_subdirectory(OpenMP) add_subdirectory(Quant) add_subdirectory(SCF) diff --git a/mlir/include/mlir/Dialect/OpenACC/CMakeLists.txt b/mlir/include/mlir/Dialect/OpenACC/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/OpenACC/CMakeLists.txt @@ -0,0 +1,9 @@ +set(LLVM_TARGET_DEFINITIONS OpenACCOps.td) +mlir_tablegen(OpenACCOpsDialect.h.inc -gen-dialect-decls -dialect=acc) +mlir_tablegen(OpenACCOps.h.inc -gen-op-decls) +mlir_tablegen(OpenACCOps.cpp.inc -gen-op-defs) +mlir_tablegen(OpenACCOpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpenACCOpsEnums.cpp.inc -gen-enum-defs) +add_mlir_doc(OpenACCOps -gen-dialect-doc OpenACCDialect Dialects/) +add_public_tablegen_target(MLIROpenACCOpsIncGen) + diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h @@ -0,0 +1,44 @@ +//===- OpenACC.h - MLIR OpenACC Dialect -------------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// ============================================================================ +// +// This file declares the OpenACC dialect in MLIR. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_OPENACC_OPENACC_H_ +#define MLIR_DIALECT_OPENACC_OPENACC_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" + +#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.h.inc" + +namespace mlir { +namespace acc { + +#define GET_OP_CLASSES +#include "mlir/Dialect/OpenACC/OpenACCOps.h.inc" + +#include "mlir/Dialect/OpenACC/OpenACCOpsDialect.h.inc" + +// Enumeration used to encode the execution mapping on a loop construct. +// The enumerators refer directly to the OpenACC 3.0 standard: +// 2.9.2. gang +// 2.9.3. worker +// 2.9.4. vector +// 2.9.5. seq +// +// Values can be combined with a bitwise OR to reflect the mapping applied to +// the construct. e.g. for `acc.loop gang vector`, `GANG` and `VECTOR` are +// combined and the final mapping value is 5 (4 | 1). +enum OpenACCExecMapping { NONE = 0, VECTOR = 1, WORKER = 2, GANG = 4, SEQ = 8 }; + +} // end namespace acc +} // end namespace mlir + +#endif // MLIR_DIALECT_OPENACC_OPENACC_H_ diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -0,0 +1,253 @@ +//===- OpenACCOps.td - OpenACC operation definitions -------*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// ============================================================================= +// +// Defines MLIR OpenACC operations. +// +//===----------------------------------------------------------------------===// + +#ifndef OPENACC_OPS +#define OPENACC_OPS + +include "mlir/IR/OpBase.td" + +def OpenACC_Dialect : Dialect { + let name = "acc"; + + let description = [{ + An OpenACC dialect for MLIR.
+ + This dialect models the construct from the OpenACC 3.0 directive language. + }]; + + let cppNamespace = "acc"; +} + +// Base class for OpenACC dialect ops. +class OpenACC_Op<string mnemonic, list<OpTrait> traits = []> : + Op<OpenACC_Dialect, mnemonic, traits> { + // For every OpenACC op, there needs to be a: + // * void print(OpAsmPrinter &p, ${C++ class of Op} op) + // * LogicalResult verify(${C++ class of Op} op) + // * ParseResult parse${C++ class of Op}(OpAsmParser &parser, + // OperationState &result) + // functions. + let printer = [{ return ::print(p, *this); }]; + let verifier = [{ return ::verify(*this); }]; + let parser = [{ return ::parse$cppClass(parser, result); }]; +} + +// Reduction operation enumeration +def OpenACC_ReductionOpAdd : StrEnumAttrCase<"redop_add">; +def OpenACC_ReductionOpMul : StrEnumAttrCase<"redop_mul">; +def OpenACC_ReductionOpMax : StrEnumAttrCase<"redop_max">; +def OpenACC_ReductionOpMin : StrEnumAttrCase<"redop_min">; +def OpenACC_ReductionOpAnd : StrEnumAttrCase<"redop_and">; +def OpenACC_ReductionOpOr : StrEnumAttrCase<"redop_or">; +def OpenACC_ReductionOpXor : StrEnumAttrCase<"redop_xor">; +def OpenACC_ReductionOpLogEqv : StrEnumAttrCase<"redop_leqv">; +def OpenACC_ReductionOpLogNeqv : StrEnumAttrCase<"redop_lneqv">; +def OpenACC_ReductionOpLogAnd : StrEnumAttrCase<"redop_land">; +def OpenACC_ReductionOpLogOr : StrEnumAttrCase<"redop_lor">; + +def OpenACC_ReductionOpAttr : StrEnumAttr<"ReductionOpAttr", + "built-in reduction operations supported by OpenACC", + [OpenACC_ReductionOpAdd, OpenACC_ReductionOpMul, OpenACC_ReductionOpMax, + OpenACC_ReductionOpMin, OpenACC_ReductionOpAnd, OpenACC_ReductionOpOr, + OpenACC_ReductionOpXor, OpenACC_ReductionOpLogEqv, + OpenACC_ReductionOpLogNeqv, OpenACC_ReductionOpLogAnd, + OpenACC_ReductionOpLogOr + ]> { + let cppNamespace = "::mlir::acc"; +} + +//===----------------------------------------------------------------------===// +// 2.5.1 parallel Construct 
+//===----------------------------------------------------------------------===// + +def OpenACC_ParallelOp : OpenACC_Op<"parallel", + [SingleBlockImplicitTerminator<"YieldOp">, + AttrSizedOperandSegments]> { + let summary = "parallel construct"; + let description = [{ + The "acc.parallel" operation represents a parallel construct block. It has + one region to be executed in parallel on the current device. + + acc.parallel { + ... // body + } + }]; + + let verifier = ?; + let arguments = (ins Optional<Index>:$async, + Variadic<Index>:$waitOperands, + Optional<Index>:$numGangs, + Optional<Index>:$numWorkers, + Optional<Index>:$vectorLength, + Optional<I1>:$ifCond, + Optional<I1>:$selfCond, + OptionalAttr<OpenACC_ReductionOpAttr>:$reductionOp, + Variadic<AnyType>:$reductionOperands, + Variadic<AnyType>:$copyOperands, + Variadic<AnyType>:$copyinOperands, + Variadic<AnyType>:$copyoutOperands, + Variadic<AnyType>:$createOperands, + Variadic<AnyType>:$noCreateOperands, + Variadic<AnyType>:$presentOperands, + Variadic<AnyType>:$devicePtrOperands, + Variadic<AnyType>:$attachOperands, + Variadic<AnyType>:$gangPrivateOperands, + Variadic<AnyType>:$gangFirstPrivateOperands); + + let regions = (region AnyRegion:$region); + + // Clause keywords used by the custom parser and printer of acc.parallel. + let extraClassDeclaration = [{ + static StringRef getAsyncKeyword() { return "async"; } + static StringRef getWaitKeyword() { return "wait"; } + static StringRef getNumGangsKeyword() { return "num_gangs"; } + static StringRef getNumWorkersKeyword() { return "num_workers"; } + static StringRef getVectorLengthKeyword() { return "vector_length"; } + static StringRef getIfKeyword() { return "if"; } + static StringRef getSelfKeyword() { return "self"; } + static StringRef getReductionKeyword() { return "reduction"; } + static StringRef getCopyKeyword() { return "copy"; } + static StringRef getCopyinKeyword() { return "copyin"; } + static StringRef getCopyoutKeyword() { return "copyout"; } + static StringRef getCreateKeyword() { return "create"; } + static
StringRef getNoCreateKeyword() { return "no_create"; } + static StringRef getPresentKeyword() { return "present"; } + static StringRef getDevicePtrKeyword() { return "deviceptr"; } + static StringRef getAttachKeyword() { return "attach"; } + static StringRef getPrivateKeyword() { return "private"; } + static StringRef getFirstPrivateKeyword() { return "firstprivate"; } + }]; +} + +//===----------------------------------------------------------------------===// +// 2.6.5 data Construct +//===----------------------------------------------------------------------===// + +def OpenACC_DataOp : OpenACC_Op<"data", + [SingleBlockImplicitTerminator<"DataEndOp">, + AttrSizedOperandSegments]> { + let summary = "data construct"; + + let description = [{ + The "acc.data" operation represents a data construct. It defines vars to + be allocated in the current device memory for the duration of the region, + whether data should be copied from local memory to the current device + memory upon region entry, and copied from device memory to local memory + upon region exit. + + acc.data { + ...
// body + } + }]; + + let verifier = ?; + + let arguments = (ins Variadic<AnyType>:$presentOperands, + Variadic<AnyType>:$copyOperands, + Variadic<AnyType>:$copyinOperands, + Variadic<AnyType>:$copyoutOperands, + Variadic<AnyType>:$createOperands, + Variadic<AnyType>:$noCreateOperands, + Variadic<AnyType>:$deleteOperands, + Variadic<AnyType>:$attachOperands, + Variadic<AnyType>:$detachOperands); + + let regions = (region AnyRegion:$region); + + let extraClassDeclaration = [{ + static StringRef getAttachKeyword() { return "attach"; } + static StringRef getDeleteKeyword() { return "delete"; } + static StringRef getDetachKeyword() { return "detach"; } + static StringRef getCopyinKeyword() { return "copyin"; } + static StringRef getCopyKeyword() { return "copy"; } + static StringRef getCopyoutKeyword() { return "copyout"; } + static StringRef getCreateKeyword() { return "create"; } + static StringRef getNoCreateKeyword() { return "no_create"; } + static StringRef getPresentKeyword() { return "present"; } + }]; +} + +def OpenACC_DataEndOp : OpenACC_Op<"_data_end", [Terminator]> { + let summary = "terminator for OpenACC data regions"; + + let description = [{ + A terminator operation for regions that appear in the body of OpenACC data + operation. Data construct regions are not expected to return any value so + the terminator takes no operands. The terminator op returns control to the + enclosing op. 
+ }]; + + let verifier = ?; + + let assemblyFormat = "attr-dict"; +} + +//===----------------------------------------------------------------------===// +// 2.9 loop Construct +//===----------------------------------------------------------------------===// + +def OpenACC_LoopOp : OpenACC_Op<"loop", + [SingleBlockImplicitTerminator<"YieldOp">, + AttrSizedOperandSegments]> { + let summary = "loop construct"; + + let description = [{ + The "acc.loop" operation represents the acc loop construct. + }]; + + let verifier = ?; + + let arguments = (ins OptionalAttr<I64Attr>:$collapse, + Variadic<AnyType>:$privateOperands, + OptionalAttr<OpenACC_ReductionOpAttr>:$reductionOp, + Variadic<AnyType>:$reductionOperands); + + let results = (outs Variadic<AnyType>:$results); + + let regions = (region AnyRegion:$region); + + // Attribute names and clause keywords used by the custom parser and printer. + let extraClassDeclaration = [{ + static StringRef getCollapseAttrName() { return "collapse"; } + static StringRef getExecutionMappingAttrName() { return "exec_mapping"; } + static StringRef getGangAttrName() { return "gang"; } + static StringRef getSeqAttrName() { return "seq"; } + static StringRef getVectorAttrName() { return "vector"; } + static StringRef getWorkerAttrName() { return "worker"; } + static StringRef getPrivateKeyword() { return "private"; } + static StringRef getReductionKeyword() { return "reduction"; } + }]; +} + +// Yield operation for the acc.loop and acc.parallel operations. +// Each allowed parent op is listed as its own entry of the ParentOneOf list. +def OpenACC_YieldOp : OpenACC_Op<"yield", [Terminator, + ParentOneOf<["ParallelOp", "LoopOp"]>]> { + let summary = "Acc yield and termination operation"; + + let description = [{ + `acc.yield` is a special terminator operation for blocks inside regions in + acc ops (parallel and loop). It returns values to the immediately enclosing + acc op.
+ }]; + + let arguments = (ins Variadic<AnyType>:$operands); + + let builders = [ + OpBuilder<"OpBuilder &builder, OperationState &result", + [{ /* nothing to do */ }]> + ]; + + let verifier = ?; + + let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; +} + +#endif // OPENACC_OPS + diff --git a/mlir/include/mlir/IR/DialectSymbolRegistry.def b/mlir/include/mlir/IR/DialectSymbolRegistry.def --- a/mlir/include/mlir/IR/DialectSymbolRegistry.def +++ b/mlir/include/mlir/IR/DialectSymbolRegistry.def @@ -20,6 +20,7 @@ DEFINE_SYM_KIND_RANGE(IREE) // IREE stands for IR Execution Engine DEFINE_SYM_KIND_RANGE(LINALG) // Linear Algebra Dialect DEFINE_SYM_KIND_RANGE(FIR) // Flang Fortran IR Dialect +DEFINE_SYM_KIND_RANGE(OPENACC) // OpenACC IR Dialect DEFINE_SYM_KIND_RANGE(OPENMP) // OpenMP IR Dialect DEFINE_SYM_KIND_RANGE(TOY) // Toy language (tutorial) Dialect DEFINE_SYM_KIND_RANGE(SPIRV) // SPIR-V dialect diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -22,6 +22,7 @@ #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/Quant/QuantOps.h" #include "mlir/Dialect/SCF/SCF.h" @@ -38,6 +39,7 @@ // all the possible dialects to be made available to the context automatically. 
inline void registerAllDialects() { static bool init_once = []() { + registerDialect<acc::OpenACCDialect>(); registerDialect<AffineDialect>(); registerDialect<avx512::AVX512Dialect>(); registerDialect<gpu::GPUDialect>(); diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -3,6 +3,7 @@ add_subdirectory(GPU) add_subdirectory(Linalg) add_subdirectory(LLVMIR) +add_subdirectory(OpenACC) add_subdirectory(OpenMP) add_subdirectory(Quant) add_subdirectory(SCF) diff --git a/mlir/lib/Dialect/OpenACC/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/OpenACC/CMakeLists.txt @@ -0,0 +1,13 @@ +add_mlir_dialect_library(MLIROpenACC + IR/OpenACC.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC + + DEPENDS + MLIROpenACCOpsIncGen + + LINK_LIBS PUBLIC + MLIRIR + ) + diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -0,0 +1,576 @@ +//===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// ============================================================================= + +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Value.h" + +#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc" + +using namespace mlir; +using namespace acc; + +//===----------------------------------------------------------------------===// +// OpenACC operations +//===----------------------------------------------------------------------===// + +OpenACCDialect::OpenACCDialect(MLIRContext *context) + : Dialect(getDialectNamespace(), context) { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc" + >(); + allowsUnknownOperations(); +} + +template <typename StructureOp> +static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, + unsigned nRegions = 1) { + + llvm::SmallVector<Region *, 2> regions; + for (unsigned i = 0; i < nRegions; ++i) + regions.push_back(state.addRegion()); + + for (Region *region : regions) { + if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{})) + return failure(); + StructureOp::ensureTerminator(*region, parser.getBuilder(), state.location); + } + + return success(); +} + +static ParseResult +parseOperandList(OpAsmParser &parser, StringRef keyword, + SmallVectorImpl<OpAsmParser::OperandType> &args, + SmallVectorImpl<Type> &argTypes, OperationState &result) { + if (failed(parser.parseOptionalKeyword(keyword))) + return success(); + + if (failed(parser.parseLParen())) + return failure(); + + // Exit early if the list is empty. 
+ if (succeeded(parser.parseOptionalRParen())) + return success(); + + // Callers share one type vector across several clauses, so remember where + // the types of this clause start and resolve `args` against those only. + const size_t firstType = argTypes.size(); + + do { + OpAsmParser::OperandType arg; + Type type; + + // The entries are SSA uses, not region argument definitions. + if (parser.parseOperand(arg) || parser.parseColonType(type)) + return failure(); + + args.push_back(arg); + argTypes.push_back(type); + } while (succeeded(parser.parseOptionalComma())); + + if (failed(parser.parseRParen())) + return failure(); + + if (parser.resolveOperands( + args, llvm::makeArrayRef(argTypes).drop_front(firstType), + parser.getCurrentLocation(), result.operands)) + return failure(); + + return success(); +} + +/// Prints ` listName(v0: t0, v1: t1, ...)` for a non-empty operand range and +/// nothing for an empty one. +static void printOperandList(Operation::operand_range operands, + StringRef listName, OpAsmPrinter &printer) { + if (operands.size() > 0) { + printer << " " << listName << "("; + bool isFirst = true; + for (auto operand : operands) { + // Entries are comma separated so that the list can be parsed back. + if (!isFirst) + printer << ", "; + isFirst = false; + printer << operand << ": " << operand.getType(); + } + printer << ")"; + } +} + +/// Parses an optional `keyword(operand)` clause whose operand has the given +/// `type`. `hasOptional` is set to true only when the keyword is present; the +/// resolved value is appended to `result.operands`. +static ParseResult parseOptionalOperand(OpAsmParser &parser, StringRef keyword, + OpAsmParser::OperandType &operand, + Type &type, bool &hasOptional, + OperationState &result) { + hasOptional = false; + if (succeeded(parser.parseOptionalKeyword(keyword))) { + hasOptional = true; + if (parser.parseLParen() || parser.parseOperand(operand) || + parser.resolveOperand(operand, type, result.operands) || + parser.parseRParen()) + return failure(); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// ParallelOp +//===----------------------------------------------------------------------===// + +/// Parse acc.parallel operation +/// operation := `acc.parallel` `async(value)?` `wait(value-list)?` +/// `num_gangs(value)?` `num_workers(value)?` +/// `vector_length(value)?` `if(value)?` `self(value)?` +/// `reduction(value-list)?` `copy(value-list)?` +/// `copyin(value-list)?` `copyout(value-list)?` +/// `create(value-list)?` `no_create(value-list)?` +/// `present(value-list)?` `deviceptr(value-list)?` +/// `attach(value-list)?` `private(value-list)?` +/// `firstprivate(value-list)?` +/// region attr-dict?
+static ParseResult parseParallelOp(OpAsmParser &parser, + OperationState &result) { + auto &builder = parser.getBuilder(); + SmallVector<OpAsmParser::OperandType, 8> privateOperands, + firstprivateOperands, createOperands, copyOperands, copyinOperands, + copyoutOperands, noCreateOperands, presentOperands, devicePtrOperands, + attachOperands, waitOperands, reductionOperands; + SmallVector<Type, 8> operandTypes; + OpAsmParser::OperandType async, numGangs, numWorkers, vectorLength, ifCond, + selfCond; + bool hasAsync = false, hasNumGangs = false, hasNumWorkers = false, + hasVectorLength = false, hasIfCond = false, hasSelfCond = false; + + Type indexType = builder.getIndexType(); + Type i1Type = builder.getI1Type(); + + // async()? + if (failed(parseOptionalOperand(parser, ParallelOp::getAsyncKeyword(), async, + indexType, hasAsync, result))) + return failure(); + + // wait()? + if (failed(parseOperandList(parser, ParallelOp::getWaitKeyword(), + waitOperands, operandTypes, result))) + return failure(); + + // num_gangs(value)? + if (failed(parseOptionalOperand(parser, ParallelOp::getNumGangsKeyword(), + numGangs, indexType, hasNumGangs, result))) + return failure(); + + // num_workers(value)? + if (failed(parseOptionalOperand(parser, ParallelOp::getNumWorkersKeyword(), + numWorkers, indexType, hasNumWorkers, + result))) + return failure(); + + // vector_length(value)? + if (failed(parseOptionalOperand(parser, ParallelOp::getVectorLengthKeyword(), + vectorLength, indexType, hasVectorLength, + result))) + return failure(); + + // if()? + if (failed(parseOptionalOperand(parser, ParallelOp::getIfKeyword(), ifCond, + i1Type, hasIfCond, result))) + return failure(); + + // self()? + if (failed(parseOptionalOperand(parser, ParallelOp::getSelfKeyword(), + selfCond, i1Type, hasSelfCond, result))) + return failure(); + + // reduction()? 
+ if (failed(parseOperandList(parser, ParallelOp::getReductionKeyword(), + reductionOperands, operandTypes, result))) + return failure(); + + // copy()? + if (failed(parseOperandList(parser, ParallelOp::getCopyKeyword(), + copyOperands, operandTypes, result))) + return failure(); + + // copyin()? + if (failed(parseOperandList(parser, ParallelOp::getCopyinKeyword(), + copyinOperands, operandTypes, result))) + return failure(); + + // copyout()? + if (failed(parseOperandList(parser, ParallelOp::getCopyoutKeyword(), + copyoutOperands, operandTypes, result))) + return failure(); + + // create()? + if (failed(parseOperandList(parser, ParallelOp::getCreateKeyword(), + createOperands, operandTypes, result))) + return failure(); + + // no_create()? + if (failed(parseOperandList(parser, ParallelOp::getNoCreateKeyword(), + noCreateOperands, operandTypes, result))) + return failure(); + + // present()? + if (failed(parseOperandList(parser, ParallelOp::getPresentKeyword(), + presentOperands, operandTypes, result))) + return failure(); + + // deviceptr()? + if (failed(parseOperandList(parser, ParallelOp::getDevicePtrKeyword(), + devicePtrOperands, operandTypes, result))) + return failure(); + + // attach()? + if (failed(parseOperandList(parser, ParallelOp::getAttachKeyword(), + attachOperands, operandTypes, result))) + return failure(); + + // private()? + if (failed(parseOperandList(parser, ParallelOp::getPrivateKeyword(), + privateOperands, operandTypes, result))) + return failure(); + + // firstprivate()? + if (failed(parseOperandList(parser, ParallelOp::getFirstPrivateKeyword(), + firstprivateOperands, operandTypes, result))) + return failure(); + + // Parallel op region + if (failed(parseRegions<ParallelOp>(parser, result))) + return failure(); + + result.addAttribute(ParallelOp::getOperandSegmentSizeAttr(), + builder.getI32VectorAttr( + {static_cast<int32_t>(hasAsync ? 1 : 0), + static_cast<int32_t>(waitOperands.size()), + static_cast<int32_t>(hasNumGangs ? 
1 : 0), + static_cast<int32_t>(hasNumWorkers ? 1 : 0), + static_cast<int32_t>(hasVectorLength ? 1 : 0), + static_cast<int32_t>(hasIfCond ? 1 : 0), + static_cast<int32_t>(hasSelfCond ? 1 : 0), + static_cast<int32_t>(reductionOperands.size()), + static_cast<int32_t>(copyOperands.size()), + static_cast<int32_t>(copyinOperands.size()), + static_cast<int32_t>(copyoutOperands.size()), + static_cast<int32_t>(createOperands.size()), + static_cast<int32_t>(noCreateOperands.size()), + static_cast<int32_t>(presentOperands.size()), + static_cast<int32_t>(devicePtrOperands.size()), + static_cast<int32_t>(attachOperands.size()), + static_cast<int32_t>(privateOperands.size()), + static_cast<int32_t>(firstprivateOperands.size())})); + + // Additional attributes + if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) + return failure(); + + return success(); +} + +/// Prints acc.parallel in its custom form. Clauses are emitted in the order +/// in which parseParallelOp expects them, and every operand of the op must be +/// printed because the segment size attribute is elided from the output. +static void print(OpAsmPrinter &printer, ParallelOp &op) { + printer << ParallelOp::getOperationName(); + + // async()? + if (auto async = op.async()) + printer << " " << ParallelOp::getAsyncKeyword() << "(" << async << ")"; + + // wait()? + printOperandList(op.waitOperands(), ParallelOp::getWaitKeyword(), printer); + + // num_gangs()? + if (auto numGangs = op.numGangs()) + printer << " " << ParallelOp::getNumGangsKeyword() << "(" << numGangs + << ")"; + + // num_workers()? + if (auto numWorkers = op.numWorkers()) + printer << " " << ParallelOp::getNumWorkersKeyword() << "(" << numWorkers + << ")"; + + // vector_length()? + if (auto vectorLength = op.vectorLength()) + printer << " " << ParallelOp::getVectorLengthKeyword() << "(" + << vectorLength << ")"; + + // if()? + if (Value ifCond = op.ifCond()) + printer << " " << ParallelOp::getIfKeyword() << "(" << ifCond << ")"; + + // self()? + if (Value selfCond = op.selfCond()) + printer << " " << ParallelOp::getSelfKeyword() << "(" << selfCond << ")"; + + // reduction()? + printOperandList(op.reductionOperands(), ParallelOp::getReductionKeyword(), + printer); + + // copy()? + printOperandList(op.copyOperands(), ParallelOp::getCopyKeyword(), printer); + + // copyin()?
+ printOperandList(op.copyinOperands(), ParallelOp::getCopyinKeyword(), + printer); + + // copyout()? + printOperandList(op.copyoutOperands(), ParallelOp::getCopyoutKeyword(), + printer); + + // create()? + printOperandList(op.createOperands(), ParallelOp::getCreateKeyword(), + printer); + + // no_create()? + printOperandList(op.noCreateOperands(), ParallelOp::getNoCreateKeyword(), + printer); + + // present()? + printOperandList(op.presentOperands(), ParallelOp::getPresentKeyword(), + printer); + + // deviceptr()? + printOperandList(op.devicePtrOperands(), ParallelOp::getDevicePtrKeyword(), + printer); + + // attach()? + printOperandList(op.attachOperands(), ParallelOp::getAttachKeyword(), + printer); + + // private()? + printOperandList(op.gangPrivateOperands(), ParallelOp::getPrivateKeyword(), + printer); + + // firstprivate()? + printOperandList(op.gangFirstPrivateOperands(), + ParallelOp::getFirstPrivateKeyword(), printer); + + // The region is printed without entry block arguments and without its + // terminator; the segment size attribute is elided from the attribute dict. + printer.printRegion(op.region(), + /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/false); + printer.printOptionalAttrDictWithKeyword( + op.getAttrs(), ParallelOp::getOperandSegmentSizeAttr()); +} + +//===----------------------------------------------------------------------===// +// DataOp +//===----------------------------------------------------------------------===// + +/// Parse acc.data operation +/// operation := `acc.data` `present(value-list)?` `copy(value-list)?` +/// `copyin(value-list)?` `copyout(value-list)?` +/// `create(value-list)?` `no_create(value-list)?` +/// `delete(value-list)?` `attach(value-list)?` +/// `detach(value-list)?` +/// region attr-dict?
+static ParseResult parseDataOp(OpAsmParser &parser, OperationState &result) { + auto &builder = parser.getBuilder(); + SmallVector<OpAsmParser::OperandType, 8> presentOperands, copyOperands, + copyinOperands, copyoutOperands, createOperands, noCreateOperands, + deleteOperands, attachOperands, detachOperands; + SmallVector<Type, 8> operandsTypes; + + // Clauses are parsed in the order of the operand segments declared for + // acc.data, so operands are resolved in declaration order. + + // present(value-list)? + if (failed(parseOperandList(parser, DataOp::getPresentKeyword(), + presentOperands, operandsTypes, result))) + return failure(); + + // copy(value-list)? + if (failed(parseOperandList(parser, DataOp::getCopyKeyword(), copyOperands, + operandsTypes, result))) + return failure(); + + // copyin(value-list)? + if (failed(parseOperandList(parser, DataOp::getCopyinKeyword(), + copyinOperands, operandsTypes, result))) + return failure(); + + // copyout(value-list)? + if (failed(parseOperandList(parser, DataOp::getCopyoutKeyword(), + copyoutOperands, operandsTypes, result))) + return failure(); + + // create(value-list)? + if (failed(parseOperandList(parser, DataOp::getCreateKeyword(), + createOperands, operandsTypes, result))) + return failure(); + + // no_create(value-list)? + if (failed(parseOperandList(parser, DataOp::getNoCreateKeyword(), + noCreateOperands, operandsTypes, result))) + return failure(); + + // delete(value-list)? + if (failed(parseOperandList(parser, DataOp::getDeleteKeyword(), + deleteOperands, operandsTypes, result))) + return failure(); + + // attach(value-list)? + if (failed(parseOperandList(parser, DataOp::getAttachKeyword(), + attachOperands, operandsTypes, result))) + return failure(); + + // detach(value-list)?
+ if (failed(parseOperandList(parser, DataOp::getDetachKeyword(), + detachOperands, operandsTypes, result))) + return failure(); + + // Data op region. Instantiate with DataOp so that the implicit terminator + // added to the region is acc._data_end rather than acc.yield. + if (failed(parseRegions<DataOp>(parser, result))) + return failure(); + + result.addAttribute( + DataOp::getOperandSegmentSizeAttr(), + builder.getI32VectorAttr({static_cast<int32_t>(presentOperands.size()), + static_cast<int32_t>(copyOperands.size()), + static_cast<int32_t>(copyinOperands.size()), + static_cast<int32_t>(copyoutOperands.size()), + static_cast<int32_t>(createOperands.size()), + static_cast<int32_t>(noCreateOperands.size()), + static_cast<int32_t>(deleteOperands.size()), + static_cast<int32_t>(attachOperands.size()), + static_cast<int32_t>(detachOperands.size())})); + + // Additional attributes + if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) + return failure(); + + return success(); +} + +static void print(OpAsmPrinter &printer, DataOp &op) { + printer << DataOp::getOperationName(); + + // present(value-list)? + printOperandList(op.presentOperands(), DataOp::getPresentKeyword(), printer); + + // copy(value-list)? + printOperandList(op.copyOperands(), DataOp::getCopyKeyword(), printer); + + // copyin(value-list)? + printOperandList(op.copyinOperands(), DataOp::getCopyinKeyword(), printer); + + // copyout(value-list)? + printOperandList(op.copyoutOperands(), DataOp::getCopyoutKeyword(), printer); + + // create(value-list)? + printOperandList(op.createOperands(), DataOp::getCreateKeyword(), printer); + + // no_create(value-list)? + printOperandList(op.noCreateOperands(), DataOp::getNoCreateKeyword(), + printer); + + // delete(value-list)? + printOperandList(op.deleteOperands(), DataOp::getDeleteKeyword(), printer); + + // attach(value-list)? + printOperandList(op.attachOperands(), DataOp::getAttachKeyword(), printer); + + // detach(value-list)?
+ printOperandList(op.detachOperands(), DataOp::getDetachKeyword(), printer); + + // The region is printed without entry block arguments and without its + // terminator; the segment size attribute is elided from the attribute dict. + printer.printRegion(op.region(), + /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/false); + printer.printOptionalAttrDictWithKeyword( + op.getAttrs(), ParallelOp::getOperandSegmentSizeAttr()); +} + +//===----------------------------------------------------------------------===// +// LoopOp +//===----------------------------------------------------------------------===// + +/// Parse acc.loop operation +/// operation := `acc.loop` `gang`? `vector`? `worker`? `seq`? +/// `private(value-list)?` `reduction(value-list)?` +/// (`->` type-list)? region attr-dict? +/// The mapping keywords are only recognized in the order listed above. +static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &result) { + + Builder &builder = parser.getBuilder(); + // Bitwise OR of the OpenACCExecMapping values of the keywords that were seen. + unsigned executionMapping = 0; + SmallVector<Type, 8> operandTypes; + SmallVector<OpAsmParser::OperandType, 8> privateOperands, reductionOperands; + + // gang? + if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangAttrName()))) + executionMapping |= OpenACCExecMapping::GANG; + + // vector? + if (succeeded(parser.parseOptionalKeyword(LoopOp::getVectorAttrName()))) + executionMapping |= OpenACCExecMapping::VECTOR; + + // worker? + if (succeeded(parser.parseOptionalKeyword(LoopOp::getWorkerAttrName()))) + executionMapping |= OpenACCExecMapping::WORKER; + + // seq? + if (succeeded(parser.parseOptionalKeyword(LoopOp::getSeqAttrName()))) + executionMapping |= OpenACCExecMapping::SEQ; + + // private()? + if (failed(parseOperandList(parser, LoopOp::getPrivateKeyword(), + privateOperands, operandTypes, result))) + return failure(); + + // reduction()? + if (failed(parseOperandList(parser, LoopOp::getReductionKeyword(), + reductionOperands, operandTypes, result))) + return failure(); + + // The mapping attribute is only attached when at least one keyword was seen. + if (executionMapping != 0) + result.addAttribute(LoopOp::getExecutionMappingAttrName(), + builder.getI64IntegerAttr(executionMapping)); + + // Parse optional results in case there is a reduce.
+ if (parser.parseOptionalArrowTypeList(result.types)) + return failure(); + + if (failed(parseRegions<LoopOp>(parser, result))) + return failure(); + + result.addAttribute(LoopOp::getOperandSegmentSizeAttr(), + builder.getI32VectorAttr( + {static_cast<int32_t>(privateOperands.size()), + static_cast<int32_t>(reductionOperands.size())})); + + if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) + return failure(); + + return success(); +} + +/// Prints acc.loop in its custom form. The execution mapping attribute is +/// printed as keywords instead of as an attribute. +static void print(OpAsmPrinter &printer, LoopOp &op) { + printer << LoopOp::getOperationName(); + bool printBlockTerminators = false; + + // A missing mapping attribute is treated as "no mapping". + unsigned execMapping = + (op.getAttrOfType<IntegerAttr>(LoopOp::getExecutionMappingAttrName()) != + nullptr) + ? op.getAttrOfType<IntegerAttr>(LoopOp::getExecutionMappingAttrName()) + .getInt() + : 0; + + // The keywords must be emitted in the order in which parseLoopOp accepts + // them (gang, vector, worker, seq); otherwise the output cannot be parsed. + if ((execMapping & OpenACCExecMapping::GANG) == OpenACCExecMapping::GANG) + printer << " " << LoopOp::getGangAttrName(); + + if ((execMapping & OpenACCExecMapping::VECTOR) == OpenACCExecMapping::VECTOR) + printer << " " << LoopOp::getVectorAttrName(); + + if ((execMapping & OpenACCExecMapping::WORKER) == OpenACCExecMapping::WORKER) + printer << " " << LoopOp::getWorkerAttrName(); + + if ((execMapping & OpenACCExecMapping::SEQ) == OpenACCExecMapping::SEQ) + printer << " " << LoopOp::getSeqAttrName(); + + // private()? + printOperandList(op.privateOperands(), LoopOp::getPrivateKeyword(), printer); + + // reduction()?
+ printOperandList(op.reductionOperands(), LoopOp::getReductionKeyword(), + printer); + + if (op.getNumResults() > 0) { + printer << " -> (" << op.getResultTypes() << ")"; + printBlockTerminators = true; + } + + printer.printRegion(op.region(), + /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/printBlockTerminators); + + printer.printOptionalAttrDictWithKeyword( + op.getAttrs(), {LoopOp::getExecutionMappingAttrName(), + LoopOp::getOperandSegmentSizeAttr()}); +} + +namespace mlir { +namespace acc { +#define GET_OP_CLASSES +#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc" +} // namespace acc +} // namespace mlir diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/OpenACC/ops.mlir @@ -0,0 +1,166 @@ +// RUN: mlir-opt %s | FileCheck %s +// Verify the printed output can be parsed. +// RUN: mlir-opt %s | mlir-opt | FileCheck %s +// Verify the generic form can be parsed. +// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s + +func @compute1(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf32>) -> memref<10x10xf32> { + %c0 = constant 0 : index + %c10 = constant 10 : index + %c1 = constant 1 : index + + acc.parallel async(%c1) { + acc.loop gang vector { + scf.for %arg3 = %c0 to %c10 step %c1 { + scf.for %arg4 = %c0 to %c10 step %c1 { + scf.for %arg5 = %c0 to %c10 step %c1 { + %a = load %A[%arg3, %arg5] : memref<10x10xf32> + %b = load %B[%arg5, %arg4] : memref<10x10xf32> + %cij = load %C[%arg3, %arg4] : memref<10x10xf32> + %p = mulf %a, %b : f32 + %co = addf %cij, %p : f32 + store %co, %C[%arg3, %arg4] : memref<10x10xf32> + } + } + } + } attributes { collapse = 3 } + } + + return %C : memref<10x10xf32> +} + +// CHECK-LABEL: func @compute1( +// CHECK-NEXT: %{{.*}} = constant 0 : index +// CHECK-NEXT: %{{.*}} = constant 10 : index +// CHECK-NEXT: %{{.*}} = constant 1 : index +// CHECK-NEXT: acc.parallel async(%{{.*}}) { +// CHECK-NEXT: 
acc.loop gang vector {
+// CHECK-NEXT:       scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK-NEXT:         scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK-NEXT:           scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK-NEXT:             %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
+// CHECK-NEXT:             %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
+// CHECK-NEXT:             %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
+// CHECK-NEXT:             %{{.*}} = mulf %{{.*}}, %{{.*}} : f32
+// CHECK-NEXT:             %{{.*}} = addf %{{.*}}, %{{.*}} : f32
+// CHECK-NEXT:             store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
+// CHECK-NEXT:           }
+// CHECK-NEXT:         }
+// CHECK-NEXT:       }
+// CHECK-NEXT:     } attributes {collapse = 3 : i64}
+// CHECK-NEXT:   }
+// CHECK-NEXT:   return %{{.*}} : memref<10x10xf32>
+// CHECK-NEXT: }
+
+// compute2: a `seq` acc.loop wrapping a triple scf.for nest, placed inside an
+// acc.parallel region that carries no clauses.
+func @compute2(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf32>) -> memref<10x10xf32> {
+  %c0 = constant 0 : index
+  %c10 = constant 10 : index
+  %c1 = constant 1 : index
+
+  acc.parallel {
+    acc.loop seq {
+      scf.for %arg3 = %c0 to %c10 step %c1 {
+        scf.for %arg4 = %c0 to %c10 step %c1 {
+          scf.for %arg5 = %c0 to %c10 step %c1 {
+            %a = load %A[%arg3, %arg5] : memref<10x10xf32>
+            %b = load %B[%arg5, %arg4] : memref<10x10xf32>
+            %cij = load %C[%arg3, %arg4] : memref<10x10xf32>
+            %p = mulf %a, %b : f32
+            %co = addf %cij, %p : f32
+            store %co, %C[%arg3, %arg4] : memref<10x10xf32>
+          }
+        }
+      }
+    }
+  }
+
+  return %C : memref<10x10xf32>
+}
+
+// CHECK-LABEL: func @compute2(
+// CHECK-NEXT:   %{{.*}} = constant 0 : index
+// CHECK-NEXT:   %{{.*}} = constant 10 : index
+// CHECK-NEXT:   %{{.*}} = constant 1 : index
+// CHECK-NEXT:   acc.parallel {
+// CHECK-NEXT:     acc.loop seq {
+// CHECK-NEXT:       scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK-NEXT:         scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK-NEXT:           scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK-NEXT:             %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
+// CHECK-NEXT:             %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
+// CHECK-NEXT:             %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
+// CHECK-NEXT:             %{{.*}} = mulf %{{.*}}, %{{.*}} : f32
+// CHECK-NEXT:             %{{.*}} = addf %{{.*}}, %{{.*}} : f32
+// CHECK-NEXT:             store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
+// CHECK-NEXT:           }
+// CHECK-NEXT:         }
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }
+// CHECK-NEXT:   }
+// CHECK-NEXT:   return %{{.*}} : memref<10x10xf32>
+// CHECK-NEXT: }
+
+
+// compute3: acc.parallel with num_gangs, num_workers and private clauses. Its
+// `gang` loop contains a `worker` loop followed by a `seq` loop.
+func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>, %d: memref<10xf32>) -> memref<10xf32> {
+  %lb = constant 0 : index
+  %st = constant 1 : index
+  %c10 = constant 10 : index
+
+  acc.parallel num_gangs(%c10) num_workers(%c10) private(%c : memref<10xf32>) {
+    acc.loop gang {
+      scf.for %x = %lb to %c10 step %st {
+        acc.loop worker {
+          scf.for %y = %lb to %c10 step %st {
+            %axy = load %a[%x, %y] : memref<10x10xf32>
+            %bxy = load %b[%x, %y] : memref<10x10xf32>
+            %tmp = addf %axy, %bxy : f32
+            store %tmp, %c[%y] : memref<10xf32>
+          }
+        }
+
+        acc.loop seq {
+          // for i = 0 to 10 step 1
+          //   d[x] += c[i]
+          scf.for %i = %lb to %c10 step %st {
+            %ci = load %c[%i] : memref<10xf32>
+            %dx = load %d[%x] : memref<10xf32>
+            %z = addf %ci, %dx : f32
+            store %z, %d[%x] : memref<10xf32>
+          }
+        }
+      }
+    }
+  }
+
+  return %d : memref<10xf32>
+}
+
+// CHECK: func @compute3({{.*}}: memref<10x10xf32>, {{.*}}: memref<10x10xf32>, [[ARG2:%.*]]: memref<10xf32>, {{.*}}: memref<10xf32>) -> memref<10xf32> {
+// CHECK-NEXT:   [[C0:%.*]] = constant 0 : index
+// CHECK-NEXT:   [[C1:%.*]] = constant 1 : index
+// CHECK-NEXT:   [[C10:%.*]] = constant 10 : index
+// CHECK-NEXT:   acc.parallel num_gangs([[C10]]) num_workers([[C10]]) private([[ARG2]]: memref<10xf32>) {
+// CHECK-NEXT:     acc.loop gang {
+// CHECK-NEXT:       scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] {
+// CHECK-NEXT:         acc.loop worker {
+// CHECK-NEXT:           scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] {
+// CHECK-NEXT:             %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
+// CHECK-NEXT:             %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
+// CHECK-NEXT:             %{{.*}} = addf %{{.*}}, %{{.*}} : f32
+// CHECK-NEXT:             store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
+// CHECK-NEXT:           }
+// CHECK-NEXT:         }
+// CHECK-NEXT:         acc.loop seq {
+// CHECK-NEXT:           scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] {
+// CHECK-NEXT:             %{{.*}} = load %{{.*}}[%{{.*}}] : memref<10xf32>
+// CHECK-NEXT:             %{{.*}} = load %{{.*}}[%{{.*}}] : memref<10xf32>
+// CHECK-NEXT:             %{{.*}} = addf %{{.*}}, %{{.*}} : f32
+// CHECK-NEXT:             store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
+// CHECK-NEXT:           }
+// CHECK-NEXT:         }
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }
+// CHECK-NEXT:   }
+// CHECK-NEXT:   return %{{.*}} : memref<10xf32>
+// CHECK-NEXT: }
+