diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt --- a/mlir/include/mlir/Dialect/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/CMakeLists.txt @@ -3,6 +3,7 @@ add_subdirectory(GPU) add_subdirectory(Linalg) add_subdirectory(LLVMIR) +add_subdirectory(OpenACC) add_subdirectory(OpenMP) add_subdirectory(Quant) add_subdirectory(SCF) diff --git a/mlir/include/mlir/Dialect/OpenACC/CMakeLists.txt b/mlir/include/mlir/Dialect/OpenACC/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/OpenACC/CMakeLists.txt @@ -0,0 +1,9 @@ +set(LLVM_TARGET_DEFINITIONS OpenACCOps.td) +mlir_tablegen(OpenACCOpsDialect.h.inc -gen-dialect-decls -dialect=acc) +mlir_tablegen(OpenACCOps.h.inc -gen-op-decls) +mlir_tablegen(OpenACCOps.cpp.inc -gen-op-defs) +mlir_tablegen(OpenACCOpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpenACCOpsEnums.cpp.inc -gen-enum-defs) +add_mlir_doc(OpenACCOps -gen-dialect-doc OpenACCDialect Dialects/) +add_public_tablegen_target(MLIROpenACCOpsIncGen) + diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h @@ -0,0 +1,41 @@ +//===- OpenACC.h - MLIR OpenACC Dialect -------------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// ============================================================================ +// +// This file declares the OpenACC dialect in MLIR. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_OPENACC_OPENACCDIALECT_H +#define MLIR_DIALECT_OPENACC_OPENACCDIALECT_H + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" + +#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.h.inc" + +namespace mlir { +namespace acc { + +#define GET_OP_CLASSES +#include "mlir/Dialect/OpenACC/OpenACCOps.h.inc" + +#include "mlir/Dialect/OpenACC/OpenACCOpsDialect.h.inc" + +// Bit flags describing the execution mapping of a loop. The values are +// distinct bits so they can be OR-ed into one integer; GANG_VECTOR equals +// GANG | VECTOR. NOTE(review): presumably stored in the `exec_mapping` +// attribute of acc.loop - confirm against the loop parser/printer. +enum OpenACCExecMapping { + NONE = 0, + VECTOR = 1, + WORKER = 2, + GANG = 4, + GANG_VECTOR = 5, + SEQ = 8 +}; + +} // end namespace acc +} // end namespace mlir + +#endif // MLIR_DIALECT_OPENACC_OPENACCDIALECT_H diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -0,0 +1,182 @@ +//===- OpenACC.td - OpenACC operation definitions ----------*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// ============================================================================= +// +// Defines MLIR OpenACC operations. +// +//===----------------------------------------------------------------------===// + +#ifndef OPENACC_OPS +#define OPENACC_OPS + +include "mlir/IR/OpBase.td" + + +def OpenACC_Dialect : Dialect { + let name = "acc"; + + let description = [{ + An OpenACC dialect for MLIR. + + This dialect models the construct from the OpenACC 3.0 directive language. + }]; + + let cppNamespace = "acc"; +} + +// Base class for OpenACC dialect ops. 
+class OpenACC_Op<string mnemonic, list<OpTrait> traits = []> :
+  Op<OpenACC_Dialect, mnemonic, traits>;
+
+
+//===----------------------------------------------------------------------===//
+// 2.5.1 parallel Construct
+//===----------------------------------------------------------------------===//
+
+// The operand order below defines the operand segments (see
+// AttrSizedOperandSegments): the custom parser must fill `result.operands` in
+// exactly this order.
+def OpenACC_ParallelOp : OpenACC_Op<"parallel",
+    [SingleBlockImplicitTerminator<"YieldOp">,
+     AttrSizedOperandSegments]> {
+  let summary = "parallel construct";
+  let description = [{
+    The "acc.parallel" operation represents a parallel construct block. It has
+    one region to be executed in parallel on the current device.
+
+    The optional clauses are written between the operation name and the region,
+    in this order: num_gangs, num_workers, vector_length, private,
+    firstprivate, create, copyin, copyout, if.
+
+      acc.parallel num_gangs(%c10) num_workers(%c10) private(%c : memref<10xf32>) {
+        ... // body
+      }
+  }];
+
+  let printer = [{ printParallelOp(*this, p); }];
+  let parser = [{ return parseParallelOp(parser, result); }];
+  let arguments = (ins Optional<Index>:$numGangs,
+                       Optional<Index>:$numWorkers,
+                       Optional<Index>:$vectorLength,
+                       Optional<I1>:$ifCond,
+                       Variadic<AnyType>:$gangPrivateOperands,
+                       Variadic<AnyType>:$gangFirstPrivateOperands,
+                       Variadic<AnyType>:$createOperands,
+                       Variadic<AnyType>:$copyinOperands,
+                       Variadic<AnyType>:$copyoutOperands);
+  let regions = (region SizedRegion<1>:$region);
+
+  // Keywords used by the custom assembly format.
+  let extraClassDeclaration = [{
+    Region &getBody() { return this->getOperation()->getRegion(0); }
+    static StringRef getNumGangsKeyword() { return "num_gangs"; }
+    static StringRef getNumWorkersKeyword() { return "num_workers"; }
+    static StringRef getVectorLengthKeyword() { return "vector_length"; }
+    static StringRef getPrivateKeyword() { return "private"; }
+    static StringRef getFirstPrivateKeyword() { return "firstprivate"; }
+    static StringRef getCreateKeyword() { return "create"; }
+    static StringRef getCopyinKeyword() { return "copyin"; }
+    static StringRef getCopyoutKeyword() { return "copyout"; }
+    static StringRef getIfKeyword() { return "if"; }
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// 2.6.5 data Construct
+//===----------------------------------------------------------------------===//
+
+// As for acc.parallel, the operand order below defines the operand segments.
+def OpenACC_DataOp : OpenACC_Op<"data",
+    [SingleBlockImplicitTerminator<"DataEndOp">,
+     AttrSizedOperandSegments]> {
+  let summary = "data construct";
+  let description = [{
+    The "acc.data" operation represents a data construct. It defines vars to
+    be allocated in the current device memory for the duration of the region,
+    whether data should be copied from local memory to the current device
+    memory upon region entry, and copied from device memory to local memory
+    upon region exit.
+
+    The optional clauses are written between the operation name and the region,
+    in this order: present, copy, copyin, copyout, create, no_create, delete,
+    attach, detach.
+
+      acc.data {
+        ... // body
+      }
+  }];
+
+  let printer = [{ printDataOp(*this, p); }];
+  let parser = [{ return parseDataOp(parser, result); }];
+  let arguments = (ins Variadic<AnyType>:$presentOperands,
+                       Variadic<AnyType>:$copyOperands,
+                       Variadic<AnyType>:$copyinOperands,
+                       Variadic<AnyType>:$copyoutOperands,
+                       Variadic<AnyType>:$createOperands,
+                       Variadic<AnyType>:$noCreateOperands,
+                       Variadic<AnyType>:$deleteOperands,
+                       Variadic<AnyType>:$attachOperands,
+                       Variadic<AnyType>:$detachOperands);
+  let regions = (region SizedRegion<1>:$region);
+
+  // Keywords used by the custom assembly format.
+  let extraClassDeclaration = [{
+    Region &getBody() { return this->getOperation()->getRegion(0); }
+    static StringRef getAttachKeyword() { return "attach"; }
+    static StringRef getDeleteKeyword() { return "delete"; }
+    static StringRef getDetachKeyword() { return "detach"; }
+    static StringRef getCopyinKeyword() { return "copyin"; }
+    static StringRef getCopyKeyword() { return "copy"; }
+    static StringRef getCopyoutKeyword() { return "copyout"; }
+    static StringRef getCreateKeyword() { return "create"; }
+    static StringRef getNoCreateKeyword() { return "no_create"; }
+    static StringRef getPresentKeyword() { return "present"; }
+  }];
+}
+
+def OpenACC_DataEndOp : OpenACC_Op<"_data_end", [Terminator]> {
+  let summary = "terminator for OpenACC data regions";
+  let description = [{
+    A terminator operation for regions that appear in the body of OpenACC data
+    operation. Data construct regions are not expected to return any value so
+    the terminator takes no operands. The terminator op returns control to the
+    enclosing op.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
+//===----------------------------------------------------------------------===//
+// 2.9 loop Construct
+//===----------------------------------------------------------------------===//
+
+def OpenACC_LoopOp : OpenACC_Op<"loop",
+    [SingleBlockImplicitTerminator<"YieldOp">]> {
+  let summary = "loop construct";
+  let description = [{
+    The "acc.loop" operation represents the OpenACC loop construct. The
+    optional `gang`, `worker`, `vector` and `seq` keywords are combined into
+    the integer `exec_mapping` attribute.
+
+      acc.loop gang vector {
+        ... // body
+      } attributes { collapse = 3 }
+  }];
+
+  let parser = [{ return parseLoopOp(parser, result); }];
+  let printer = [{ printLoopOp(*this, p); }];
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region SizedRegion<1>:$region);
+
+  // Attribute names and keywords used by the custom assembly format.
+  let extraClassDeclaration = [{
+    Region &getBody() { return this->getOperation()->getRegion(0); }
+    static StringRef getCollapseAttrName() { return "collapse"; }
+    static StringRef getExecutionMappingAttrName() { return "exec_mapping"; }
+    static StringRef getGangAttrName() { return "gang"; }
+    static StringRef getSeqAttrName() { return "seq"; }
+    static StringRef getVectorAttrName() { return "vector"; }
+    static StringRef getWorkerAttrName() { return "worker"; }
+  }];
+}
+
+// Yield operation for the acc.loop and acc.parallel operations.
+def OpenACC_YieldOp : OpenACC_Op<"yield", [Terminator]> {
+  let summary = "Acc yield and termination operation";
+  let description = [{
+    The "acc.yield" operation terminates the region of an "acc.loop" or
+    "acc.parallel" operation. Its operands are the values returned by the
+    enclosing operation: their number and types must match the results of the
+    parent operation.
+  }];
+  let arguments = (ins Variadic<AnyType>:$operands);
+  let verifier = [{ return ::verify(*this); }];
+  // Builder without operands, needed to create the implicit terminator.
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &result",
+              [{ /* nothing to do */ }]>
+  ];
+
+  let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
+}
+
+#endif // OPENACC_OPS
+
diff --git a/mlir/include/mlir/IR/DialectSymbolRegistry.def b/mlir/include/mlir/IR/DialectSymbolRegistry.def --- a/mlir/include/mlir/IR/DialectSymbolRegistry.def +++ b/mlir/include/mlir/IR/DialectSymbolRegistry.def @@ -20,6 +20,7 @@ DEFINE_SYM_KIND_RANGE(IREE) // IREE stands for IR Execution Engine DEFINE_SYM_KIND_RANGE(LINALG) // Linear Algebra Dialect DEFINE_SYM_KIND_RANGE(FIR) // Flang Fortran IR Dialect +DEFINE_SYM_KIND_RANGE(OPENACC) // OpenACC IR Dialect DEFINE_SYM_KIND_RANGE(OPENMP) // OpenMP IR Dialect DEFINE_SYM_KIND_RANGE(TOY) // Toy language (tutorial) Dialect DEFINE_SYM_KIND_RANGE(SPIRV) // SPIR-V dialect diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -22,6 +22,7 @@ #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/Quant/QuantOps.h" #include "mlir/Dialect/SCF/SCF.h" @@ -38,6 +39,7 @@ // all the possible dialects to be made available to the context automatically. 
inline void registerAllDialects() { static bool init_once = []() { + registerDialect(); registerDialect(); registerDialect(); registerDialect(); diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -3,6 +3,7 @@ add_subdirectory(GPU) add_subdirectory(Linalg) add_subdirectory(LLVMIR) +add_subdirectory(OpenACC) add_subdirectory(OpenMP) add_subdirectory(Quant) add_subdirectory(SCF) diff --git a/mlir/lib/Dialect/OpenACC/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/OpenACC/CMakeLists.txt @@ -0,0 +1,13 @@ +add_mlir_dialect_library(MLIROpenACC + IR/OpenACC.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC + + DEPENDS + MLIROpenACCOpsIncGen + + LINK_LIBS PUBLIC + MLIRIR + ) + diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -0,0 +1,490 @@ +//===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// =============================================================================
+
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Value.h"
+
+#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
+
+using namespace mlir;
+using namespace acc;
+
+//===----------------------------------------------------------------------===//
+// OpenACC operations
+//===----------------------------------------------------------------------===//
+
+OpenACCDialect::OpenACCDialect(MLIRContext *context)
+    : Dialect(getDialectNamespace(), context) {
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
+      >();
+  // NOTE(review): `allowsUnknownOperations()` looks like the query accessor
+  // rather than the setter (`allowUnknownOperations()`), so this call
+  // presumably has no effect - confirm the intent and drop or replace it.
+  allowsUnknownOperations();
+}
+
+/// Adds `nRegions` regions to `state`, parses each of them without entry block
+/// arguments and makes sure each one ends with the implicit terminator of
+/// `StructureOp`.
+template <typename StructureOp>
+static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
+                                unsigned int nRegions = 1) {
+  llvm::SmallVector<Region *, 2> regions;
+  for (unsigned int i = 0; i < nRegions; ++i)
+    regions.push_back(state.addRegion());
+
+  for (auto &region : regions) {
+    if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
+      return failure();
+    StructureOp::ensureTerminator(*region, parser.getBuilder(), state.location);
+  }
+
+  return success();
+}
+
+/// Parses an optional trailing `attributes {...}` dictionary into `state`.
+static ParseResult parseOptionalAttributes(OpAsmParser &parser,
+                                           OperationState &state) {
+  if (succeeded(parser.parseOptionalKeyword("attributes"))) {
+    if (parser.parseOptionalAttrDict(state.attributes))
+      return failure();
+  }
+  return success();
+}
+
+/// Parses an optional clause of the form `keyword(%v0 : type0, %v1 : type1)`.
+/// The parsed operands and types are appended to `args` and `argTypes`, and
+/// the resolved values are appended to `result.operands`. A missing keyword or
+/// an empty list `keyword()` is not an error.
+static ParseResult
+parseOperandList(OpAsmParser &parser, StringRef keyword,
+                 SmallVectorImpl<OpAsmParser::OperandType> &args,
+                 SmallVectorImpl<Type> &argTypes, OperationState &result) {
+  if (failed(parser.parseOptionalKeyword(keyword)))
+    return success();
+
+  if (failed(parser.parseLParen()))
+    return failure();
+
+  // Exit if empty list
+  if (succeeded(parser.parseOptionalRParen()))
+    return success();
+
+  // Callers share one `argTypes` vector between several clauses, so it may
+  // already hold the types of earlier lists. Remember where this list starts
+  // to pair every operand with its own type below.
+  const size_t firstArg = args.size();
+  const size_t firstType = argTypes.size();
+
+  do {
+    OpAsmParser::OperandType arg;
+    Type type;
+
+    if (parser.parseRegionArgument(arg) || parser.parseColonType(type))
+      return failure();
+
+    args.push_back(arg);
+    argTypes.push_back(type);
+  } while (succeeded(parser.parseOptionalComma()));
+
+  if (failed(parser.parseRParen()))
+    return failure();
+
+  for (size_t i = 0, e = args.size() - firstArg; i < e; ++i) {
+    if (parser.resolveOperand(args[firstArg + i], argTypes[firstType + i],
+                              result.operands))
+      return failure();
+  }
+  return success();
+}
+
+/// Prints ` listName(%v0: type0, %v1: type1)`; prints nothing when `operands`
+/// is empty. The output can be read back by parseOperandList.
+static void printOperandList(Operation::operand_range operands,
+                             StringRef listName, OpAsmPrinter &printer) {
+  if (operands.empty())
+    return;
+
+  printer << " " << listName << "(";
+  // Operands are comma separated, as expected by the parser.
+  llvm::interleaveComma(operands, printer, [&](Value operand) {
+    printer << operand << ": " << operand.getType();
+  });
+  printer << ")";
+}
+
+/// Parses an optional clause of the form `keyword(%value)`. `hasOptional`
+/// tells whether the clause was present; when it is, the operand is resolved
+/// with `type` and appended to `result.operands`.
+static ParseResult parseOptionalOperand(OpAsmParser &parser, StringRef keyword,
+                                        OpAsmParser::OperandType &operand,
+                                        Type &type, bool &hasOptional,
+                                        OperationState &result) {
+  hasOptional = false;
+  if (succeeded(parser.parseOptionalKeyword(keyword))) {
+    hasOptional = true;
+    if (parser.parseLParen() || parser.parseOperand(operand) ||
+        parser.resolveOperand(operand, type, result.operands) ||
+        parser.parseRParen())
+      return failure();
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ParallelOp
+//===----------------------------------------------------------------------===//
+
+/// Parse acc.parallel operation
+/// operation := `acc.parallel` `num_gangs(value)?` `num_workers(value)?`
+///              `vector_length(value)?` `private(value-list)?`
+///              `firstprivate(value-list)?` `create(value-list)?`
+///              `copyin(value-list)?` `copyout(value-list)?` `if(value)?`
+///              region attr-dict?
+static ParseResult parseParallelOp(OpAsmParser &parser,
+                                   OperationState &result) {
+  auto &builder = parser.getBuilder();
+  SmallVector<OpAsmParser::OperandType, 8> privateOperands,
+      firstprivateOperands, createOperands, copyinOperands, copyoutOperands;
+  SmallVector<Type, 8> operandTypes;
+  OpAsmParser::OperandType numGangs, numWorkers, vectorLength, ifCond;
+  bool hasNumGangs = false, hasNumWorkers = false, hasVectorLength = false,
+       hasIfCond = false;
+
+  Type indexType = builder.getIndexType();
+  Type i1Type = builder.getI1Type();
+
+  // num_gangs(value)?
+  if (failed(parseOptionalOperand(parser, ParallelOp::getNumGangsKeyword(),
+                                  numGangs, indexType, hasNumGangs, result)))
+    return failure();
+
+  // num_workers(value)?
+  if (failed(parseOptionalOperand(parser, ParallelOp::getNumWorkersKeyword(),
+                                  numWorkers, indexType, hasNumWorkers,
+                                  result)))
+    return failure();
+
+  // vector_length(value)?
+  if (failed(parseOptionalOperand(parser, ParallelOp::getVectorLengthKeyword(),
+                                  vectorLength, indexType, hasVectorLength,
+                                  result)))
+    return failure();
+
+  // private()?
+  if (failed(parseOperandList(parser, ParallelOp::getPrivateKeyword(),
+                              privateOperands, operandTypes, result)))
+    return failure();
+
+  // firstprivate()?
+  if (failed(parseOperandList(parser, ParallelOp::getFirstPrivateKeyword(),
+                              firstprivateOperands, operandTypes, result)))
+    return failure();
+
+  // create()?
+  if (failed(parseOperandList(parser, ParallelOp::getCreateKeyword(),
+                              createOperands, operandTypes, result)))
+    return failure();
+
+  // copyin()?
+  if (failed(parseOperandList(parser, ParallelOp::getCopyinKeyword(),
+                              copyinOperands, operandTypes, result)))
+    return failure();
+
+  // copyout()?
+  if (failed(parseOperandList(parser, ParallelOp::getCopyoutKeyword(),
+                              copyoutOperands, operandTypes, result)))
+    return failure();
+
+  // if()?
+  if (failed(parseOptionalOperand(parser, ParallelOp::getIfKeyword(), ifCond,
+                                  i1Type, hasIfCond, result)))
+    return failure();
+
+  // The `if` clause is written last, but the op declares `ifCond` as its
+  // fourth operand group (right after `vectorLength`). Move the value that was
+  // just appended to that position so the operands match the segment sizes
+  // recorded below.
+  if (hasIfCond) {
+    Value ifCondValue = result.operands.pop_back_val();
+    unsigned ifCondPos = (hasNumGangs ? 1 : 0) + (hasNumWorkers ? 1 : 0) +
+                         (hasVectorLength ? 1 : 0);
+    result.operands.insert(result.operands.begin() + ifCondPos, ifCondValue);
+  }
+
+  // Parallel op region
+  if (failed(parseRegions<ParallelOp>(parser, result)))
+    return failure();
+
+  result.addAttribute(ParallelOp::getOperandSegmentSizeAttr(),
+                      builder.getI32VectorAttr(
+                          {static_cast<int32_t>(hasNumGangs ? 1 : 0),
+                           static_cast<int32_t>(hasNumWorkers ? 1 : 0),
+                           static_cast<int32_t>(hasVectorLength ? 1 : 0),
+                           static_cast<int32_t>(hasIfCond ? 1 : 0),
+                           static_cast<int32_t>(privateOperands.size()),
+                           static_cast<int32_t>(firstprivateOperands.size()),
+                           static_cast<int32_t>(createOperands.size()),
+                           static_cast<int32_t>(copyinOperands.size()),
+                           static_cast<int32_t>(copyoutOperands.size())}));
+
+  // Additional attributes
+  if (failed(parseOptionalAttributes(parser, result)))
+    return failure();
+
+  return success();
+}
+
+/// Prints acc.parallel with its clauses in the order accepted by
+/// parseParallelOp.
+static void printParallelOp(ParallelOp &op, OpAsmPrinter &printer) {
+  printer << ParallelOp::getOperationName();
+
+  // num_gangs()?
+  if (auto numGangs = op.numGangs())
+    printer << " " << ParallelOp::getNumGangsKeyword() << "(" << numGangs
+            << ")";
+
+  // num_workers()?
+  if (auto numWorkers = op.numWorkers())
+    printer << " " << ParallelOp::getNumWorkersKeyword() << "(" << numWorkers
+            << ")";
+
+  // vector_length()?
+  if (auto vectorLength = op.vectorLength())
+    printer << " " << ParallelOp::getVectorLengthKeyword() << "("
+            << vectorLength << ")";
+
+  // private()?
+  printOperandList(op.gangPrivateOperands(), ParallelOp::getPrivateKeyword(),
+                   printer);
+
+  // firstprivate()?
+  printOperandList(op.gangFirstPrivateOperands(),
+                   ParallelOp::getFirstPrivateKeyword(), printer);
+
+  // create()?
+  printOperandList(op.createOperands(), ParallelOp::getCreateKeyword(),
+                   printer);
+
+  // copyin()?
+  printOperandList(op.copyinOperands(), ParallelOp::getCopyinKeyword(),
+                   printer);
+
+  // copyout()?
+  printOperandList(op.copyoutOperands(), ParallelOp::getCopyoutKeyword(),
+                   printer);
+
+  // if()?
+  if (auto ifCond = op.ifCond())
+    printer << " " << ParallelOp::getIfKeyword() << "(" << ifCond << ")";
+
+  printer.printRegion(op.getBody(), false, false);
+  // The segment sizes are rebuilt by the parser, so they are not printed.
+  printer.printOptionalAttrDictWithKeyword(
+      op.getAttrs(), ParallelOp::getOperandSegmentSizeAttr());
+}
+
+//===----------------------------------------------------------------------===//
+// DataOp
+//===----------------------------------------------------------------------===//
+
+/// Parse acc.data operation
+/// operation := `acc.data` `present(value-list)?` `copy(value-list)?`
+///              `copyin(value-list)?` `copyout(value-list)?`
+///              `create(value-list)?` `no_create(value-list)?`
+///              `delete(value-list)?` `attach(value-list)?`
+///              `detach(value-list)?`
+///              region attr-dict?
+static ParseResult parseDataOp(OpAsmParser &parser, OperationState &result) {
+  auto &builder = parser.getBuilder();
+  SmallVector<OpAsmParser::OperandType, 8> presentOperands, copyOperands,
+      copyinOperands, copyoutOperands, createOperands, noCreateOperands,
+      deleteOperands, attachOperands, detachOperands;
+  SmallVector<Type, 8> operandsTypes;
+
+  // present(value-list)?
+  if (failed(parseOperandList(parser, DataOp::getPresentKeyword(),
+                              presentOperands, operandsTypes, result)))
+    return failure();
+
+  // copy(value-list)?
+  if (failed(parseOperandList(parser, DataOp::getCopyKeyword(), copyOperands,
+                              operandsTypes, result)))
+    return failure();
+
+  // copyin(value-list)?
+  if (failed(parseOperandList(parser, DataOp::getCopyinKeyword(),
+                              copyinOperands, operandsTypes, result)))
+    return failure();
+
+  // copyout(value-list)?
+  if (failed(parseOperandList(parser, DataOp::getCopyoutKeyword(),
+                              copyoutOperands, operandsTypes, result)))
+    return failure();
+
+  // create(value-list)?
+  if (failed(parseOperandList(parser, DataOp::getCreateKeyword(),
+                              createOperands, operandsTypes, result)))
+    return failure();
+
+  // no_create(value-list)?
+  if (failed(parseOperandList(parser, DataOp::getNoCreateKeyword(),
+                              noCreateOperands, operandsTypes, result)))
+    return failure();
+
+  // delete(value-list)?
+  if (failed(parseOperandList(parser, DataOp::getDeleteKeyword(),
+                              deleteOperands, operandsTypes, result)))
+    return failure();
+
+  // attach(value-list)?
+  if (failed(parseOperandList(parser, DataOp::getAttachKeyword(),
+                              attachOperands, operandsTypes, result)))
+    return failure();
+
+  // detach(value-list)?
+  if (failed(parseOperandList(parser, DataOp::getDetachKeyword(),
+                              detachOperands, operandsTypes, result)))
+    return failure();
+
+  // Data op region
+  if (failed(parseRegions<DataOp>(parser, result)))
+    return failure();
+
+  result.addAttribute(
+      DataOp::getOperandSegmentSizeAttr(),
+      builder.getI32VectorAttr({static_cast<int32_t>(presentOperands.size()),
+                                static_cast<int32_t>(copyOperands.size()),
+                                static_cast<int32_t>(copyinOperands.size()),
+                                static_cast<int32_t>(copyoutOperands.size()),
+                                static_cast<int32_t>(createOperands.size()),
+                                static_cast<int32_t>(noCreateOperands.size()),
+                                static_cast<int32_t>(deleteOperands.size()),
+                                static_cast<int32_t>(attachOperands.size()),
+                                static_cast<int32_t>(detachOperands.size())}));
+
+  // Additional attributes
+  if (failed(parseOptionalAttributes(parser, result)))
+    return failure();
+
+  return success();
+}
+
+/// Prints acc.data with its clauses in the order accepted by parseDataOp.
+static void printDataOp(DataOp &op, OpAsmPrinter &printer) {
+  printer << DataOp::getOperationName();
+
+  // present(value-list)?
+  printOperandList(op.presentOperands(), DataOp::getPresentKeyword(), printer);
+
+  // copy(value-list)?
+  printOperandList(op.copyOperands(), DataOp::getCopyKeyword(), printer);
+
+  // copyin(value-list)?
+  printOperandList(op.copyinOperands(), DataOp::getCopyinKeyword(), printer);
+
+  // copyout(value-list)?
+  printOperandList(op.copyoutOperands(), DataOp::getCopyoutKeyword(), printer);
+
+  // create(value-list)?
+  printOperandList(op.createOperands(), DataOp::getCreateKeyword(), printer);
+
+  // no_create(value-list)?
+  printOperandList(op.noCreateOperands(), DataOp::getNoCreateKeyword(),
+                   printer);
+
+  // delete(value-list)?
+  printOperandList(op.deleteOperands(), DataOp::getDeleteKeyword(), printer);
+
+  // attach(value-list)?
+  printOperandList(op.attachOperands(), DataOp::getAttachKeyword(), printer);
+
+  // detach(value-list)?
+  printOperandList(op.detachOperands(), DataOp::getDetachKeyword(), printer);
+
+  printer.printRegion(op.getBody(), false, false);
+  // The segment sizes are rebuilt by the parser, so they are not printed.
+  printer.printOptionalAttrDictWithKeyword(
+      op.getAttrs(), DataOp::getOperandSegmentSizeAttr());
+}
+
+//===----------------------------------------------------------------------===//
+// LoopOp
+//===----------------------------------------------------------------------===//
+
+/// Parse acc.loop operation
+/// operation := `acc.loop` (`gang` | `worker` | `vector` | `seq`)*
+///              (`->` type-list)? region attr-dict?
+static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &result) {
+  Builder &builder = parser.getBuilder();
+  unsigned executionMapping = 0;
+
+  // Tries to consume `keyword`; on success records `flag` in the mapping.
+  auto parseMappingKeyword = [&](StringRef keyword, unsigned flag) {
+    if (failed(parser.parseOptionalKeyword(keyword)))
+      return false;
+    executionMapping |= flag;
+    return true;
+  };
+
+  // The keywords are accepted in any order, so both the order emitted by
+  // printLoopOp (gang, worker, vector, seq) and hand-written variants parse.
+  while (parseMappingKeyword(LoopOp::getGangAttrName(),
+                             OpenACCExecMapping::GANG) ||
+         parseMappingKeyword(LoopOp::getWorkerAttrName(),
+                             OpenACCExecMapping::WORKER) ||
+         parseMappingKeyword(LoopOp::getVectorAttrName(),
+                             OpenACCExecMapping::VECTOR) ||
+         parseMappingKeyword(LoopOp::getSeqAttrName(),
+                             OpenACCExecMapping::SEQ)) {
+  }
+
+  if (executionMapping != 0)
+    result.addAttribute(LoopOp::getExecutionMappingAttrName(),
+                        builder.getI64IntegerAttr(executionMapping));
+
+  // Parse optional results in case there is a reduce.
+  if (parser.parseOptionalArrowTypeList(result.types))
+    return failure();
+
+  if (failed(parseRegions<LoopOp>(parser, result)))
+    return failure();
+
+  if (failed(parseOptionalAttributes(parser, result)))
+    return failure();
+
+  return success();
+}
+
+/// Prints acc.loop: the execution mapping keywords decoded from the
+/// `exec_mapping` attribute, the optional result types, the region and the
+/// remaining attributes.
+static void printLoopOp(LoopOp &op, OpAsmPrinter &printer) {
+  printer << LoopOp::getOperationName();
+  bool printBlockTerminators = false;
+
+  unsigned execMapping = 0;
+  if (auto mappingAttr = op.getAttrOfType<IntegerAttr>(
+          LoopOp::getExecutionMappingAttrName()))
+    execMapping = mappingAttr.getInt();
+
+  if ((execMapping & OpenACCExecMapping::GANG) == OpenACCExecMapping::GANG)
+    printer << " " << LoopOp::getGangAttrName();
+
+  if ((execMapping & OpenACCExecMapping::WORKER) == OpenACCExecMapping::WORKER)
+    printer << " " << LoopOp::getWorkerAttrName();
+
+  if ((execMapping & OpenACCExecMapping::VECTOR) == OpenACCExecMapping::VECTOR)
+    printer << " " << LoopOp::getVectorAttrName();
+
+  if ((execMapping & OpenACCExecMapping::SEQ) == OpenACCExecMapping::SEQ)
+    printer << " " << LoopOp::getSeqAttrName();
+
+  // The terminator carries the returned values, so it has to be printed when
+  // the loop has results.
+  if (op.getNumResults() > 0) {
+    printer << " -> (" << op.getResultTypes() << ")";
+    printBlockTerminators = true;
+  }
+
+  printer.printRegion(op.getBody(), false, printBlockTerminators);
+
+  // `exec_mapping` is already printed as keywords above.
+  printer.printOptionalAttrDictWithKeyword(
+      op.getAttrs(), {LoopOp::getExecutionMappingAttrName()});
+}
+
+//===----------------------------------------------------------------------===//
+// YieldOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies that acc.yield is nested directly in an acc.loop or acc.parallel
+/// and that its operands match the results of that parent in number and type.
+static LogicalResult verify(YieldOp op) {
+  auto parentOp = op.getParentOp();
+  if (!isa<LoopOp>(parentOp) && !isa<ParallelOp>(parentOp))
+    return op.emitOpError() << "yield only terminates Loop or Parallel regions";
+
+  if (parentOp->getNumResults() != op.getNumOperands())
+    return op.emitOpError() << "parent of yield must have same number of "
+                               "results as the yield operands";
+
+  auto results = parentOp->getResults();
+  auto operands = op.getOperands();
+  for (auto e : llvm::zip(results, operands)) {
+    if (std::get<0>(e).getType() != std::get<1>(e).getType())
+      return op.emitOpError()
+             << "types mismatch between yield op and its parent";
+  }
+  return success();
+}
+
+namespace mlir {
+namespace acc {
+#define GET_OP_CLASSES
+#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
+} // namespace acc
+} // namespace mlir
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/OpenACC/ops.mlir @@ -0,0 +1,166 @@ +// RUN: mlir-opt %s | FileCheck %s +// Verify the printed output can be parsed. +// RUN: mlir-opt %s | mlir-opt | FileCheck %s +// Verify the generic form can be parsed. +// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s + +func @compute1(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf32>) -> memref<10x10xf32> { + %c0 = constant 0 : index + %c10 = constant 10 : index + %c1 = constant 1 : index + + acc.parallel { + acc.loop gang vector { + scf.for %arg3 = %c0 to %c10 step %c1 { + scf.for %arg4 = %c0 to %c10 step %c1 { + scf.for %arg5 = %c0 to %c10 step %c1 { + %a = load %A[%arg3, %arg5] : memref<10x10xf32> + %b = load %B[%arg5, %arg4] : memref<10x10xf32> + %cij = load %C[%arg3, %arg4] : memref<10x10xf32> + %p = mulf %a, %b : f32 + %co = addf %cij, %p : f32 + store %co, %C[%arg3, %arg4] : memref<10x10xf32> + } + } + } + } attributes { collapse = 3 } + } attributes { async = 1 } + + return %C : memref<10x10xf32> +} + +// CHECK-LABEL: func @compute1( +// CHECK-NEXT: %{{.*}} = constant 0 : index +// CHECK-NEXT: %{{.*}} = constant 10 : index +// CHECK-NEXT: %{{.*}} = constant 1 : index +// CHECK-NEXT: acc.parallel { +// CHECK-NEXT: acc.loop gang vector { +// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { +// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { +// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { +// CHECK-NEXT: %{{.*}} = load 
%{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32> +// CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32> +// CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32> +// CHECK-NEXT: %{{.*}} = mulf %{{.*}}, %{{.*}} : f32 +// CHECK-NEXT: %{{.*}} = addf %{{.*}}, %{{.*}} : f32 +// CHECK-NEXT: store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } attributes {collapse = 3 : i64} +// CHECK-NEXT: } attributes {async = 1 : i64} +// CHECK-NEXT: return %{{.*}} : memref<10x10xf32> +// CHECK-NEXT: } + +func @compute2(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf32>) -> memref<10x10xf32> { + %c0 = constant 0 : index + %c10 = constant 10 : index + %c1 = constant 1 : index + + acc.parallel { + acc.loop seq { + scf.for %arg3 = %c0 to %c10 step %c1 { + scf.for %arg4 = %c0 to %c10 step %c1 { + scf.for %arg5 = %c0 to %c10 step %c1 { + %a = load %A[%arg3, %arg5] : memref<10x10xf32> + %b = load %B[%arg5, %arg4] : memref<10x10xf32> + %cij = load %C[%arg3, %arg4] : memref<10x10xf32> + %p = mulf %a, %b : f32 + %co = addf %cij, %p : f32 + store %co, %C[%arg3, %arg4] : memref<10x10xf32> + } + } + } + } + } + + return %C : memref<10x10xf32> +} + +// CHECK-LABEL: func @compute2( +// CHECK-NEXT: %{{.*}} = constant 0 : index +// CHECK-NEXT: %{{.*}} = constant 10 : index +// CHECK-NEXT: %{{.*}} = constant 1 : index +// CHECK-NEXT: acc.parallel { +// CHECK-NEXT: acc.loop seq { +// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { +// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { +// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { +// CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32> +// CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32> +// CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32> +// CHECK-NEXT: %{{.*}} = mulf %{{.*}}, %{{.*}} : f32 +// 
CHECK-NEXT: %{{.*}} = addf %{{.*}}, %{{.*}} : f32 +// CHECK-NEXT: store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return %{{.*}} : memref<10x10xf32> +// CHECK-NEXT: } + + +func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>, %d: memref<10xf32>) -> memref<10xf32> { + %lb = constant 0 : index + %st = constant 1 : index + %c10 = constant 10 : index + + acc.parallel num_gangs(%c10) num_workers(%c10) private(%c : memref<10xf32>) { + acc.loop gang { + scf.for %x = %lb to %c10 step %st { + acc.loop worker { + scf.for %y = %lb to %c10 step %st { + %axy = load %a[%x, %y] : memref<10x10xf32> + %bxy = load %b[%x, %y] : memref<10x10xf32> + %tmp = addf %axy, %bxy : f32 + store %tmp, %c[%y] : memref<10xf32> + } + } + + acc.loop seq { + // for i = 0 to 10 step 1 + // d[x] += c[i] + scf.for %i = %lb to %c10 step %st { + %ci = load %c[%i] : memref<10xf32> + %dx = load %d[%x] : memref<10xf32> + %z = addf %ci, %dx : f32 + store %z, %d[%x] : memref<10xf32> + } + } + } + } + } + + return %d : memref<10xf32> +} + +// CHECK: func @compute3({{.*}}: memref<10x10xf32>, {{.*}}: memref<10x10xf32>, [[ARG2:%.*]]: memref<10xf32>, {{.*}}: memref<10xf32>) -> memref<10xf32> { +// CHECK-NEXT: [[C0:%.*]] = constant 0 : index +// CHECK-NEXT: [[C1:%.*]] = constant 1 : index +// CHECK-NEXT: [[C10:%.*]] = constant 10 : index +// CHECK-NEXT: acc.parallel num_gangs([[C10]]) num_workers([[C10]]) private([[ARG2]]: memref<10xf32>) { +// CHECK-NEXT: acc.loop gang { +// CHECK-NEXT: scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] { +// CHECK-NEXT: acc.loop worker { +// CHECK-NEXT: scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] { +// CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32> +// CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32> +// CHECK-NEXT: %{{.*}} = addf %{{.*}}, %{{.*}} : f32 +// CHECK-NEXT: store %{{.*}}, 
%{{.*}}[%{{.*}}] : memref<10xf32> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: acc.loop seq { +// CHECK-NEXT: scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] { +// CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}] : memref<10xf32> +// CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}] : memref<10xf32> +// CHECK-NEXT: %{{.*}} = addf %{{.*}}, %{{.*}} : f32 +// CHECK-NEXT: store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return %{{.*}} : memref<10xf32> +// CHECK-NEXT: } +