diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -101,15 +101,30 @@
   return success();
 }
 
+/// Parse an operand followed by `:` and its type, then resolve the operand
+/// into `result.operands`.
+/// Example:
+///   %vectorLength: i64
+static ParseResult parseOperandAndType(OpAsmParser &parser,
+                                       OperationState &result) {
+  OpAsmParser::OperandType operand;
+  Type type;
+  if (parser.parseOperand(operand) || parser.parseColonType(type) ||
+      parser.resolveOperand(operand, type, result.operands))
+    return failure();
+  return success();
+}
+
+/// Parse an optional operand and its type wrapped in parentheses and prefixed
+/// with a keyword. Returns None if the keyword is absent, failure if the
+/// keyword is present but the rest is malformed, success otherwise.
+/// Example:
+///   keyword `(` %vectorLength: i64 `)`
 static OptionalParseResult parseOptionalOperandAndType(OpAsmParser &parser,
                                                        StringRef keyword,
                                                        OperationState &result) {
-  OpAsmParser::OperandType operand;
-  Type type;
   if (succeeded(parser.parseOptionalKeyword(keyword))) {
-    if (parser.parseLParen() || parser.parseOperand(operand) ||
-        parser.parseColonType(type) ||
-        parser.resolveOperand(operand, type, result.operands) ||
+    if (parser.parseLParen() || parseOperandAndType(parser, result) ||
         parser.parseRParen())
       return failure();
     return success();
@@ -117,6 +132,37 @@
   return llvm::None;
 }
 
+/// Parse an optional operand and its type wrapped in parentheses. Returns None
+/// if there is no opening parenthesis, failure if the contents are malformed,
+/// success otherwise.
+/// Example:
+///   `(` %vectorLength: i64 `)`
+static OptionalParseResult parseOptionalOperandAndType(OpAsmParser &parser,
+                                                       OperationState &result) {
+  if (succeeded(parser.parseOptionalLParen())) {
+    if (parseOperandAndType(parser, result) || parser.parseRParen())
+      return failure();
+    return success();
+  }
+  return llvm::None;
+}
+
+/// Parse an optional operand with its type, prefixed with `prefixKeyword` and
+/// `=`. Returns None if the keyword is absent, failure if the `=` or the
+/// operand/type is malformed, success otherwise.
+/// Example:
+///   num=%gangNum: i32
+static OptionalParseResult parserOptionalOperandAndTypeWithPrefix(
+    OpAsmParser &parser, OperationState &result, StringRef prefixKeyword) {
+  if (succeeded(parser.parseOptionalKeyword(prefixKeyword))) {
+    // Propagate a missing `=` instead of silently continuing after the error.
+    if (parser.parseEqual())
+      return failure();
+    return parseOperandAndType(parser, result);
+  }
+  return llvm::None;
+}
+
 //===----------------------------------------------------------------------===//
 // ParallelOp
 //===----------------------------------------------------------------------===//
@@ -600,10 +646,8 @@
   SmallVector<Type, 8> operandTypes;
   SmallVector<OpAsmParser::OperandType, 8> privateOperands, reductionOperands;
   SmallVector<OpAsmParser::OperandType, 8> tileOperands;
-  bool hasWorkerNum = false, hasVectorLength = false, hasGangNum = false;
-  bool hasGangStatic = false;
-  OpAsmParser::OperandType workerNum, vectorLength, gangNum, gangStatic;
-  Type gangNumType, gangStaticType, workerType, vectorLengthType;
+  // Each result stays empty (hasValue() == false) unless its clause is parsed.
+  OptionalParseResult gangNum, gangStatic, worker, vector;
 
   // gang?
   if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangKeyword())))
@@ -611,25 +655,16 @@
 
   // optional gang operand
   if (succeeded(parser.parseOptionalLParen())) {
-    if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangNumKeyword()))) {
-      hasGangNum = true;
-      parser.parseEqual();
-      if (parser.parseOperand(gangNum) || parser.parseColonType(gangNumType) ||
-          parser.resolveOperand(gangNum, gangNumType, result.operands)) {
-        return failure();
-      }
-    }
+    gangNum = parserOptionalOperandAndTypeWithPrefix(
+        parser, result, LoopOp::getGangNumKeyword());
+    if (gangNum.hasValue() && failed(*gangNum))
+      return failure();
+    // Only one optional comma, between num= and static=; no trailing comma.
     parser.parseOptionalComma();
-    if (succeeded(
-            parser.parseOptionalKeyword(LoopOp::getGangStaticKeyword()))) {
-      hasGangStatic = true;
-      parser.parseEqual();
-      if (parser.parseOperand(gangStatic) ||
-          parser.parseColonType(gangStaticType) ||
-          parser.resolveOperand(gangStatic, gangStaticType, result.operands)) {
-        return failure();
-      }
-    }
+    gangStatic = parserOptionalOperandAndTypeWithPrefix(
+        parser, result, LoopOp::getGangStaticKeyword());
+    if (gangStatic.hasValue() && failed(*gangStatic))
+      return failure();
     if (failed(parser.parseRParen()))
       return failure();
   }
@@ -639,30 +674,18 @@
     executionMapping |= OpenACCExecMapping::WORKER;
 
   // optional worker operand
-  if (succeeded(parser.parseOptionalLParen())) {
-    hasWorkerNum = true;
-    if (parser.parseOperand(workerNum) || parser.parseColonType(workerType) ||
-        parser.resolveOperand(workerNum, workerType, result.operands) ||
-        parser.parseRParen()) {
-      return failure();
-    }
-  }
+  worker = parseOptionalOperandAndType(parser, result);
+  if (worker.hasValue() && failed(*worker))
+    return failure();
 
   // vector?
   if (succeeded(parser.parseOptionalKeyword(LoopOp::getVectorKeyword())))
     executionMapping |= OpenACCExecMapping::VECTOR;
 
   // optional vector operand
-  if (succeeded(parser.parseOptionalLParen())) {
-    hasVectorLength = true;
-    if (parser.parseOperand(vectorLength) ||
-        parser.parseColonType(vectorLengthType) ||
-        parser.resolveOperand(vectorLength, vectorLengthType,
-                              result.operands) ||
-        parser.parseRParen()) {
-      return failure();
-    }
-  }
+  vector = parseOptionalOperandAndType(parser, result);
+  if (vector.hasValue() && failed(*vector))
+    return failure();
 
   // tile()?
   if (failed(parseOperandList(parser, LoopOp::getTileKeyword(), tileOperands,
@@ -692,10 +715,10 @@
 
   result.addAttribute(LoopOp::getOperandSegmentSizeAttr(),
                       builder.getI32VectorAttr(
-                          {static_cast<int32_t>(hasGangNum ? 1 : 0),
-                           static_cast<int32_t>(hasGangStatic ? 1 : 0),
-                           static_cast<int32_t>(hasWorkerNum ? 1 : 0),
-                           static_cast<int32_t>(hasVectorLength ? 1 : 0),
+                          {static_cast<int32_t>(gangNum.hasValue() ? 1 : 0),
+                           static_cast<int32_t>(gangStatic.hasValue() ? 1 : 0),
+                           static_cast<int32_t>(worker.hasValue() ? 1 : 0),
+                           static_cast<int32_t>(vector.hasValue() ? 1 : 0),
                            static_cast<int32_t>(tileOperands.size()),
                            static_cast<int32_t>(privateOperands.size()),
                            static_cast<int32_t>(reductionOperands.size())}));