diff --git a/mlir/lib/Parser/AffineParser.cpp b/mlir/lib/Parser/AffineParser.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Parser/AffineParser.cpp
@@ -0,0 +1,726 @@
+//===- AffineParser.cpp - MLIR Affine Parser ------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a parser for Affine structures.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Parser.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/IntegerSet.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+using llvm::SMLoc;
+
+namespace {
+
+/// Additive operators, which bind more loosely than the multiplicative ones
+/// below (both share one precedence level). The parser tests the result of
+/// consumeIfLowPrecOp() as a boolean, so the "no operator" value must be 0.
+enum AffineLowPrecOp {
+  LNoOp = 0, ///< No operator; converts to false.
+  Add,
+  Sub
+};
+
+/// Multiplicative operators, all at the same (higher) precedence level. As
+/// with AffineLowPrecOp, the "no operator" value must be 0 because the result
+/// of consumeIfHighPrecOp() is tested as a boolean.
+enum AffineHighPrecOp {
+  HNoOp = 0, ///< No operator; converts to false.
+  Mul,
+  FloorDiv,
+  CeilDiv,
+  Mod
+};
+
+/// A parser specialized for affine structures -- affine maps, affine
+/// expressions, and integer sets -- that also holds the state which only lives
+/// while one of their bodies is being parsed.
+class AffineParser : public Parser {
+public:
+  /// Creates a parser over `state`. When `allowParsingSSAIds` is true, SSA
+  /// values (`%id` and `symbol(%id)`) may appear inside expressions, and
+  /// `parseElement` is invoked the first time each one is seen, with `true`
+  /// when it is used as a symbol and `false` when it is used as a dimension.
+  AffineParser(ParserState &state, bool allowParsingSSAIds = false,
+               function_ref<ParseResult(bool)> parseElement = nullptr)
+      : Parser(state), allowParsingSSAIds(allowParsingSSAIds),
+        parseElement(parseElement), numDimOperands(0), numSymbolOperands(0) {}
+
+  AffineMap parseAffineMapRange(unsigned numDims, unsigned numSymbols);
+  ParseResult parseAffineMapOrIntegerSetInline(AffineMap &map, IntegerSet &set);
+  IntegerSet parseIntegerSetConstraints(unsigned numDims, unsigned numSymbols);
+  ParseResult parseAffineMapOfSSAIds(AffineMap &map,
+                                     OpAsmParser::Delimiter delimiter);
+  // NOTE(review): declared here but not defined anywhere in this file; the
+  // element type is assumed to be StringRef (the SSA names) -- confirm or
+  // remove the declaration.
+  void getDimsAndSymbolSSAIds(SmallVectorImpl<StringRef> &dimAndSymbolSSAIds,
+                              unsigned &numDims);
+
+private:
+  // Binary affine op parsing.
+  AffineLowPrecOp consumeIfLowPrecOp();
+  AffineHighPrecOp consumeIfHighPrecOp();
+
+  // Identifier lists for polyhedral structures.
+  ParseResult parseDimIdList(unsigned &numDims);
+  ParseResult parseSymbolIdList(unsigned &numSymbols);
+  ParseResult parseDimAndOptionalSymbolIdList(unsigned &numDims,
+                                              unsigned &numSymbols);
+  ParseResult parseIdentifierDefinition(AffineExpr idExpr);
+
+  AffineExpr parseAffineExpr();
+  AffineExpr parseParentheticalExpr();
+  AffineExpr parseNegateExpression(AffineExpr lhs);
+  AffineExpr parseIntegerExpr();
+  AffineExpr parseBareIdExpr();
+  AffineExpr parseSSAIdExpr(bool isSymbol);
+  AffineExpr parseSymbolSSAIdExpr();
+
+  AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr lhs,
+                                   AffineExpr rhs, llvm::SMLoc opLoc);
+  AffineExpr getAffineBinaryOpExpr(AffineLowPrecOp op, AffineExpr lhs,
+                                   AffineExpr rhs);
+  AffineExpr parseAffineOperandExpr(AffineExpr lhs);
+  AffineExpr parseAffineLowPrecOpExpr(AffineExpr llhs, AffineLowPrecOp llhsOp);
+  AffineExpr parseAffineHighPrecOpExpr(AffineExpr llhs, AffineHighPrecOp llhsOp,
+                                       llvm::SMLoc llhsOpLoc);
+  AffineExpr parseAffineConstraint(bool *isEq);
+
+  /// Whether SSA ids may appear as operands of affine expressions.
+  bool allowParsingSSAIds;
+  /// Callback used to parse an SSA value use; its argument is true when the
+  /// use is a symbol.
+  function_ref<ParseResult(bool)> parseElement;
+  /// Number of distinct SSA ids bound so far as dimensions and as symbols.
+  unsigned numDimOperands;
+  unsigned numSymbolOperands;
+  /// Identifier name -> affine expression bindings, in definition order.
+  SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols;
+};
+} // end anonymous namespace
+
+/// Create an affine binary high precedence op expression (mul's, div's, mod).
+/// opLoc is the location of the op token to be used to report errors
+/// for non-conforming expressions. Returns a null expression after emitting
+/// an error when the result would not be affine.
+AffineExpr AffineParser::getAffineBinaryOpExpr(AffineHighPrecOp op,
+                                               AffineExpr lhs, AffineExpr rhs,
+                                               SMLoc opLoc) {
+  // TODO: make the error location info accurate.
+  switch (op) {
+  case Mul:
+    // A product is affine only if at least one side is constant/symbolic.
+    if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant()) {
+      emitError(opLoc, "non-affine expression: at least one of the multiply "
+                       "operands has to be either a constant or symbolic");
+      return nullptr;
+    }
+    return lhs * rhs;
+  case FloorDiv:
+    if (!rhs.isSymbolicOrConstant()) {
+      emitError(opLoc, "non-affine expression: right operand of floordiv "
+                       "has to be either a constant or symbolic");
+      return nullptr;
+    }
+    return lhs.floorDiv(rhs);
+  case CeilDiv:
+    if (!rhs.isSymbolicOrConstant()) {
+      emitError(opLoc, "non-affine expression: right operand of ceildiv "
+                       "has to be either a constant or symbolic");
+      return nullptr;
+    }
+    return lhs.ceilDiv(rhs);
+  case Mod:
+    if (!rhs.isSymbolicOrConstant()) {
+      emitError(opLoc, "non-affine expression: right operand of mod "
+                       "has to be either a constant or symbolic");
+      return nullptr;
+    }
+    return lhs % rhs;
+  case HNoOp:
+    llvm_unreachable("can't create affine expression for null high prec op");
+  }
+  llvm_unreachable("Unknown AffineHighPrecOp");
+}
+
+/// Create an affine binary low precedence op expression (add, sub).
+AffineExpr AffineParser::getAffineBinaryOpExpr(AffineLowPrecOp op,
+                                               AffineExpr lhs, AffineExpr rhs) {
+  // Additive combinations of affine expressions are always affine, so unlike
+  // the multiplicative overload there is no failure path.
+  if (op == AffineLowPrecOp::Add)
+    return lhs + rhs;
+  if (op == AffineLowPrecOp::Sub)
+    return lhs - rhs;
+  if (op == AffineLowPrecOp::LNoOp)
+    llvm_unreachable("can't create affine expression for null low prec op");
+  llvm_unreachable("Unknown AffineLowPrecOp");
+}
+
+/// If the current token is `+` or `-`, consume it and return the matching
+/// operator; otherwise leave the token stream untouched and return LNoOp.
+/// (There are only two precedence levels.)
+AffineLowPrecOp AffineParser::consumeIfLowPrecOp() {
+  const Token::Kind kind = getToken().getKind();
+  AffineLowPrecOp op = AffineLowPrecOp::LNoOp;
+  if (kind == Token::plus)
+    op = AffineLowPrecOp::Add;
+  else if (kind == Token::minus)
+    op = AffineLowPrecOp::Sub;
+
+  // Only an actual operator token is consumed.
+  if (op != AffineLowPrecOp::LNoOp)
+    consumeToken(kind);
+  return op;
+}
+
+/// If the current token is `*`, `floordiv`, `ceildiv` or `mod`, consume it and
+/// return the matching operator; otherwise leave the token stream untouched
+/// and return HNoOp.
+AffineHighPrecOp AffineParser::consumeIfHighPrecOp() {
+  const Token::Kind kind = getToken().getKind();
+  AffineHighPrecOp op;
+  switch (kind) {
+  case Token::star:
+    op = Mul;
+    break;
+  case Token::kw_floordiv:
+    op = FloorDiv;
+    break;
+  case Token::kw_ceildiv:
+    op = CeilDiv;
+    break;
+  case Token::kw_mod:
+    op = Mod;
+    break;
+  default:
+    return HNoOp;
+  }
+  consumeToken(kind);
+  return op;
+}
+
+/// Parse a high precedence op expression list: mul, div, and mod are high
+/// precedence binary ops, i.e., parse a
+///   expr_1 op_1 expr_2 op_2 ... expr_n
+/// where op_1, op_2 are all an AffineHighPrecOp (mul, div, mod).
+/// All affine binary ops are left associative.
+/// Given llhs, returns (llhs llhsOp lhs) op rhs, or (lhs op rhs) if llhs is
+/// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is
+/// null. llhsOpLoc is the location of the llhsOp token that will be used to
+/// report an error for non-conforming expressions.
+AffineExpr AffineParser::parseAffineHighPrecOpExpr(AffineExpr llhs,
+                                                   AffineHighPrecOp llhsOp,
+                                                   SMLoc llhsOpLoc) {
+  AffineExpr lhs = parseAffineOperandExpr(llhs);
+  if (!lhs)
+    return nullptr;
+
+  // Found an LHS. Parse the remaining expression.
+  auto opLoc = getToken().getLoc();
+  if (AffineHighPrecOp op = consumeIfHighPrecOp()) {
+    if (llhs) {
+      // Fold (llhs llhsOp lhs) first to keep left associativity. Any error in
+      // this fold concerns llhsOp, so report it at llhsOpLoc rather than at
+      // the location of the operator that follows.
+      AffineExpr expr = getAffineBinaryOpExpr(llhsOp, llhs, lhs, llhsOpLoc);
+      if (!expr)
+        return nullptr;
+      return parseAffineHighPrecOpExpr(expr, op, opLoc);
+    }
+    // No llhs: `lhs` becomes the left operand of `op`.
+    return parseAffineHighPrecOpExpr(lhs, op, opLoc);
+  }
+
+  // This is the last operand in this expression.
+  if (llhs)
+    return getAffineBinaryOpExpr(llhsOp, llhs, lhs, llhsOpLoc);
+
+  // No llhs, 'lhs' itself is the expression.
+  return lhs;
+}
+
+/// Parse an affine expression inside parentheses.
+///
+///   affine-expr ::= `(` affine-expr `)`
+AffineExpr AffineParser::parseParentheticalExpr() {
+  if (parseToken(Token::l_paren, "expected '('"))
+    return nullptr;
+  if (getToken().is(Token::r_paren))
+    return (emitError("no expression inside parentheses"), nullptr);
+
+  auto expr = parseAffineExpr();
+  if (!expr)
+    return nullptr;
+  if (parseToken(Token::r_paren, "expected ')'"))
+    return nullptr;
+
+  return expr;
+}
+
+/// Parse the negation expression.
+///
+///   affine-expr ::= `-` affine-expr
+///
+/// `lhs` is forwarded to parseAffineOperandExpr so that its diagnostics know
+/// whether the negation is the right operand of a binary operator.
+AffineExpr AffineParser::parseNegateExpression(AffineExpr lhs) {
+  if (parseToken(Token::minus, "expected '-'"))
+    return nullptr;
+
+  // Since negation has the highest precedence of all ops (including high
+  // precedence ops) but lower than parentheses, we are only going to use
+  // parseAffineOperandExpr instead of parseAffineExpr here.
+  AffineExpr operand = parseAffineOperandExpr(lhs);
+  if (!operand)
+    // Extra error message although parseAffineOperandExpr would have
+    // complained. Leads to a better diagnostic.
+    return (emitError("missing operand of negation"), nullptr);
+  return (-1) * operand;
+}
+
+/// Parse a bare id that may appear in an affine expression. The id must
+/// already be bound in `dimsAndSymbols`.
+///
+///   affine-expr ::= bare-id
+AffineExpr AffineParser::parseBareIdExpr() {
+  if (getToken().isNot(Token::bare_identifier))
+    return (emitError("expected bare identifier"), nullptr);
+
+  StringRef sRef = getTokenSpelling();
+  for (const auto &entry : dimsAndSymbols) {
+    if (entry.first == sRef) {
+      consumeToken(Token::bare_identifier);
+      return entry.second;
+    }
+  }
+
+  return (emitError("use of undeclared identifier"), nullptr);
+}
+
+/// Parse an SSA id which may appear in an affine expression. The first use of
+/// a name calls `parseElement(isSymbol)` and binds the name to the next free
+/// dimension (or symbol, if `isSymbol`) position; later uses of the same name
+/// return that binding.
+/// NOTE(review): a repeated name reuses its first binding even if `isSymbol`
+/// differs between the uses.
+AffineExpr AffineParser::parseSSAIdExpr(bool isSymbol) {
+  if (!allowParsingSSAIds)
+    return (emitError("unexpected ssa identifier"), nullptr);
+  if (getToken().isNot(Token::percent_identifier))
+    return (emitError("expected ssa identifier"), nullptr);
+  auto name = getTokenSpelling();
+  // Check if we already parsed this SSA id.
+  for (const auto &entry : dimsAndSymbols) {
+    if (entry.first == name) {
+      consumeToken(Token::percent_identifier);
+      return entry.second;
+    }
+  }
+  // Parse the SSA id and add an AffineDim/SymbolExpr to represent it.
+  if (parseElement(isSymbol))
+    return (emitError("failed to parse ssa identifier"), nullptr);
+  auto idExpr = isSymbol
+                    ? getAffineSymbolExpr(numSymbolOperands++, getContext())
+                    : getAffineDimExpr(numDimOperands++, getContext());
+  dimsAndSymbols.push_back({name, idExpr});
+  return idExpr;
+}
+
+/// Parse an SSA id used as a symbol.
+///
+///   affine-expr ::= `symbol` `(` ssa-id `)`
+AffineExpr AffineParser::parseSymbolSSAIdExpr() {
+  if (parseToken(Token::kw_symbol, "expected symbol keyword") ||
+      parseToken(Token::l_paren, "expected '(' at start of SSA symbol"))
+    return nullptr;
+  AffineExpr symbolExpr = parseSSAIdExpr(/*isSymbol=*/true);
+  if (!symbolExpr)
+    return nullptr;
+  if (parseToken(Token::r_paren, "expected ')' at end of SSA symbol"))
+    return nullptr;
+  return symbolExpr;
+}
+
+/// Parse a non-negative integral constant appearing in an affine expression.
+///
+///   affine-expr ::= integer-literal
+///
+/// The literal must produce a value that is non-negative when viewed as an
+/// int64_t; anything else is rejected with "constant too large for index".
+AffineExpr AffineParser::parseIntegerExpr() {
+  auto literal = getToken().getUInt64IntegerValue();
+  const bool fitsInInt64 =
+      literal.hasValue() && static_cast<int64_t>(literal.getValue()) >= 0;
+  if (!fitsInInt64) {
+    emitError("constant too large for index");
+    return nullptr;
+  }
+
+  consumeToken(Token::integer);
+  return builder.getAffineConstantExpr(static_cast<int64_t>(*literal));
+}
+
+/// Parse an expression that can serve as an operand of an affine expression:
+/// a bare id, an SSA id, `symbol(ssa-id)`, an integer, a parenthesized
+/// expression, or a negation.
+/// `lhs`, when non-null, is the left-hand side of a binary operator whose
+/// right operand is being parsed; it only affects which error is reported
+/// when no operand is found.
+// For example, in `i + j + k + l` each identifier is an operand. In
+// `i + j*k + l`, `j*k` is not an operand but an op expression, parsed by
+// parseAffineHighPrecOpExpr(). In `i + (j*k) + -l`, however, `(j*k)` and `-l`
+// are operands parsed by this function.
+AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) {
+  // Set when the current token is a binary operator sitting where an operand
+  // was expected; it selects the "missing left operand" diagnostic.
+  bool sawBinaryOpToken = false;
+  switch (getToken().getKind()) {
+  case Token::bare_identifier:
+    return parseBareIdExpr();
+  case Token::kw_symbol:
+    return parseSymbolSSAIdExpr();
+  case Token::percent_identifier:
+    return parseSSAIdExpr(/*isSymbol=*/false);
+  case Token::integer:
+    return parseIntegerExpr();
+  case Token::l_paren:
+    return parseParentheticalExpr();
+  case Token::minus:
+    return parseNegateExpression(lhs);
+  case Token::kw_ceildiv:
+  case Token::kw_floordiv:
+  case Token::kw_mod:
+  case Token::plus:
+  case Token::star:
+    sawBinaryOpToken = true;
+    break;
+  default:
+    break;
+  }
+
+  // No operand starts at the current token.
+  if (lhs)
+    emitError("missing right operand of binary operator");
+  else if (sawBinaryOpToken)
+    emitError("missing left operand of binary operator");
+  else
+    emitError("expected affine expression");
+  return nullptr;
+}
+
+/// Parse affine expressions that are bare-id's, integer constants,
+/// parenthetical affine expressions, and affine op expressions that are a
+/// composition of those.
+///
+/// All binary ops associate from left to right.
+///
+/// {add, sub} have lower precedence than {mul, floordiv, ceildiv, mod}.
+///
+/// Add and sub are themselves at the same precedence level. Mul, floordiv,
+/// ceildiv, and mod are at the same higher precedence level. Negation has
+/// higher precedence than any binary op.
+///
+/// llhs: the affine expression appearing on the left of the one being parsed.
+/// This function will return ((llhs llhsOp lhs) op rhs) if llhs is non null,
+/// and lhs op rhs otherwise; if there is no rhs, llhs llhsOp lhs is returned
+/// if llhs is non-null; otherwise lhs is returned. This is to deal with left
+/// associativity.
+///
+/// For example, when the expression is e1 + e2*e3 + e4, with e1 as llhs, this
+/// function will return the affine expr equivalent of (e1 + (e2*e3)) + e4,
+/// where (e2*e3) will be parsed using parseAffineHighPrecOpExpr().
+AffineExpr AffineParser::parseAffineLowPrecOpExpr(AffineExpr llhs,
+                                                  AffineLowPrecOp llhsOp) {
+  // Parse the next operand; `llhs` is passed only to pick the right
+  // diagnostic when the operand is missing.
+  AffineExpr lhs;
+  if (!(lhs = parseAffineOperandExpr(llhs)))
+    return nullptr;
+
+  // Found an LHS. Deal with the ops.
+  if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) {
+    if (llhs) {
+      // Fold the pending (llhs llhsOp lhs) before recursing, which keeps
+      // additive operators left associative.
+      AffineExpr sum = getAffineBinaryOpExpr(llhsOp, llhs, lhs);
+      return parseAffineLowPrecOpExpr(sum, lOp);
+    }
+    // No LLHS, get RHS and form the expression.
+    return parseAffineLowPrecOpExpr(lhs, lOp);
+  }
+  // Location of the (possible) high precedence operator, used for its
+  // non-affine diagnostics.
+  auto opLoc = getToken().getLoc();
+  if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) {
+    // We have a higher precedence op here. Get the rhs operand for the llhs
+    // through parseAffineHighPrecOpExpr.
+    AffineExpr highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc);
+    if (!highRes)
+      return nullptr;
+
+    // If llhs is null, the product forms the first operand of the yet to be
+    // found expression. If non-null, the op to associate with llhs is llhsOp.
+    AffineExpr expr =
+        llhs ? getAffineBinaryOpExpr(llhsOp, llhs, highRes) : highRes;
+
+    // Recurse for subsequent low prec op's after the affine high prec op
+    // expression.
+    if (AffineLowPrecOp nextOp = consumeIfLowPrecOp())
+      return parseAffineLowPrecOpExpr(expr, nextOp);
+    return expr;
+  }
+  // Last operand in the expression list.
+  if (llhs)
+    return getAffineBinaryOpExpr(llhsOp, llhs, lhs);
+  // No llhs, 'lhs' itself is the expression.
+  return lhs;
+}
+
+/// Parse an affine expression.
+///  affine-expr ::= `(` affine-expr `)`
+///                | `-` affine-expr
+///                | affine-expr `+` affine-expr
+///                | affine-expr `-` affine-expr
+///                | affine-expr `*` affine-expr
+///                | affine-expr `floordiv` affine-expr
+///                | affine-expr `ceildiv` affine-expr
+///                | affine-expr `mod` affine-expr
+///                | bare-id
+///                | integer-literal
+///                | ssa-id                       (only when SSA ids are allowed)
+///                | `symbol` `(` ssa-id `)`      (only when SSA ids are allowed)
+///
+/// Additional conditions are checked depending on the production.
+/// For example, at least one operand of `*` has to be a constant or symbolic,
+/// and the right operand of floordiv, ceildiv, and mod has to be a constant or
+/// symbolic.
+AffineExpr AffineParser::parseAffineExpr() {
+  // Begin a fresh expression: nothing pending on the left, no pending operator.
+  return parseAffineLowPrecOpExpr(/*llhs=*/nullptr, AffineLowPrecOp::LNoOp);
+}
+
+/// Bind the bare identifier at the current token to `idExpr`. Identifiers come
+/// from the dimension or symbol list preceding an affine map or integer set
+/// body; defining the same name twice is an error.
+ParseResult AffineParser::parseIdentifierDefinition(AffineExpr idExpr) {
+  if (!getToken().is(Token::bare_identifier))
+    return emitError("expected bare identifier");
+
+  StringRef name = getTokenSpelling();
+  for (const auto &binding : dimsAndSymbols)
+    if (binding.first == name)
+      return emitError("redefinition of identifier '" + name + "'");
+
+  consumeToken(Token::bare_identifier);
+  dimsAndSymbols.push_back({name, idExpr});
+  return success();
+}
+
+/// Parse the parenthesized list of dimensional identifiers. Each name is bound
+/// to the dimension at position `numDims`, which is then incremented.
+ParseResult AffineParser::parseDimIdList(unsigned &numDims) {
+  auto defineNextDim = [&]() -> ParseResult {
+    AffineExpr dim = getAffineDimExpr(numDims++, getContext());
+    return parseIdentifierDefinition(dim);
+  };
+  if (parseToken(Token::l_paren,
+                 "expected '(' at start of dimensional identifiers list"))
+    return failure();
+  return parseCommaSeparatedListUntil(Token::r_paren, defineNextDim);
+}
+
+/// Parse the square-bracketed list of symbolic identifiers. Each name is bound
+/// to the symbol at position `numSymbols`, which is then incremented. The
+/// caller in this file only calls this when the current token is '['.
+ParseResult AffineParser::parseSymbolIdList(unsigned &numSymbols) {
+  auto defineNextSymbol = [&]() -> ParseResult {
+    AffineExpr sym = getAffineSymbolExpr(numSymbols++, getContext());
+    return parseIdentifierDefinition(sym);
+  };
+  consumeToken(Token::l_square);
+  return parseCommaSeparatedListUntil(Token::r_square, defineNextSymbol);
+}
+
+/// Parse the list of dimensional identifiers followed by an optional list of
+/// symbolic identifiers.
+ParseResult
+AffineParser::parseDimAndOptionalSymbolIdList(unsigned &numDims,
+                                              unsigned &numSymbols) {
+  if (parseDimIdList(numDims))
+    return failure();
+  // The symbol list is optional; when present it starts with '['.
+  if (getToken().is(Token::l_square))
+    return parseSymbolIdList(numSymbols);
+  numSymbols = 0;
+  return success();
+}
+
+/// Parse an inline affine map or integer set. Both share the same
+/// dimension/symbol list prefix, so the token after it decides which one is
+/// being parsed: '->' starts an affine map range (assigned to `map`) and ':'
+/// starts an integer set constraint list (assigned to `set`).
+ParseResult AffineParser::parseAffineMapOrIntegerSetInline(AffineMap &map,
+                                                           IntegerSet &set) {
+  unsigned numDims = 0, numSymbols = 0;
+  if (parseDimAndOptionalSymbolIdList(numDims, numSymbols))
+    return failure();
+
+  if (getToken().is(Token::arrow)) {
+    consumeToken(Token::arrow);
+    map = parseAffineMapRange(numDims, numSymbols);
+    return map ? success() : failure();
+  }
+
+  if (getToken().isNot(Token::colon))
+    return emitError("expected '->' or ':'");
+  consumeToken(Token::colon);
+
+  set = parseIntegerSetConstraints(numDims, numSymbols);
+  return set ? success() : failure();
+}
+
+/// Parse an AffineMap whose dimension and symbol identifiers are SSA ids.
+ParseResult
+AffineParser::parseAffineMapOfSSAIds(AffineMap &map,
+                                     OpAsmParser::Delimiter delimiter) {
+  Token::Kind rightToken;
+  switch (delimiter) {
+  case OpAsmParser::Delimiter::Square:
+    if (parseToken(Token::l_square, "expected '['"))
+      return failure();
+    rightToken = Token::r_square;
+    break;
+  case OpAsmParser::Delimiter::Paren:
+    if (parseToken(Token::l_paren, "expected '('"))
+      return failure();
+    rightToken = Token::r_paren;
+    break;
+  default:
+    return emitError("unexpected delimiter");
+  }
+
+  SmallVector<AffineExpr, 4> exprs;
+  auto parseElt = [&]() -> ParseResult {
+    auto elt = parseAffineExpr();
+    exprs.push_back(elt);
+    return elt ? success() : failure();
+  };
+
+  // Parse a multi-dimensional affine expression (a comma-separated list of
+  // 1-d affine expressions); the list can be empty. Grammar:
+  // multi-dim-affine-expr ::= `(` `)`
+  //                         | `(` affine-expr (`,` affine-expr)* `)`
+  if (parseCommaSeparatedListUntil(rightToken, parseElt,
+                                   /*allowEmptyList=*/true))
+    return failure();
+  // Parsed a valid affine map. Every SSA id bound while parsing is either a
+  // dimension or a symbol, so the bindings beyond the dimensions are symbols.
+  map = AffineMap::get(numDimOperands, dimsAndSymbols.size() - numDimOperands,
+                       exprs, getContext());
+  return success();
+}
+
+/// Parse the range and sizes affine map definition inline.
+///
+///  affine-map ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr
+///
+///  multi-dim-affine-expr ::= `(` `)`
+///  multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)`
+///
+/// Returns a null AffineMap after emitting an error on failure.
+AffineMap AffineParser::parseAffineMapRange(unsigned numDims,
+                                            unsigned numSymbols) {
+  // Stop at a missing '(' instead of parsing the list from the wrong token,
+  // which could otherwise yield a map even though an error was reported.
+  if (parseToken(Token::l_paren, "expected '(' at start of affine map range"))
+    return AffineMap();
+
+  SmallVector<AffineExpr, 4> exprs;
+  auto parseElt = [&]() -> ParseResult {
+    auto elt = parseAffineExpr();
+    ParseResult res = elt ? success() : failure();
+    exprs.push_back(elt);
+    return res;
+  };
+
+  // Parse a multi-dimensional affine expression (a comma-separated list of
+  // 1-d affine expressions). Grammar:
+  // multi-dim-affine-expr ::= `(` `)`
+  //                         | `(` affine-expr (`,` affine-expr)* `)`
+  if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true))
+    return AffineMap();
+
+  // Parsed a valid affine map.
+  return AffineMap::get(numDims, numSymbols, exprs, getContext());
+}
+
+/// Parse an affine constraint.
+///  affine-constraint ::= affine-expr `>=` `0`
+///                      | affine-expr `==` `0`
+///
+/// isEq is set to true if the parsed constraint is an equality, false if it
+/// is an inequality (greater than or equal). Returns a null expression after
+/// emitting an error on failure.
+///
+AffineExpr AffineParser::parseAffineConstraint(bool *isEq) {
+  AffineExpr expr = parseAffineExpr();
+  if (!expr)
+    return nullptr;
+
+  // affine-constraint ::= affine-expr `>=` `0`
+  // Once '>' has been consumed only the `>=` form can match; do not fall
+  // through to the `==` check below, which would accept `expr >= == 0`.
+  if (consumeIf(Token::greater)) {
+    if (consumeIf(Token::equal) && getToken().is(Token::integer)) {
+      auto dim = getToken().getUnsignedIntegerValue();
+      if (dim.hasValue() && dim.getValue() == 0) {
+        consumeToken(Token::integer);
+        *isEq = false;
+        return expr;
+      }
+      return (emitError("expected '0' after '>='"), nullptr);
+    }
+    return (emitError("expected '== 0' or '>= 0' at end of affine constraint"),
+            nullptr);
+  }
+
+  // affine-constraint ::= affine-expr `==` `0`
+  if (consumeIf(Token::equal) && consumeIf(Token::equal) &&
+      getToken().is(Token::integer)) {
+    auto dim = getToken().getUnsignedIntegerValue();
+    if (dim.hasValue() && dim.getValue() == 0) {
+      consumeToken(Token::integer);
+      *isEq = true;
+      return expr;
+    }
+    return (emitError("expected '0' after '=='"), nullptr);
+  }
+
+  return (emitError("expected '== 0' or '>= 0' at end of affine constraint"),
+          nullptr);
+}
+
+/// Parse the constraints that are part of an integer set definition.
+///  integer-set-inline
+///                ::= dim-and-symbol-id-lists `:`
+///                '(' affine-constraint-conjunction?
+///                ')'
+///  affine-constraint-conjunction ::= affine-constraint (`,`
+///                                       affine-constraint)*
+///
+/// An empty constraint list denotes the trivially true set, represented by the
+/// single equality `0 == 0`. Returns a null IntegerSet on failure.
+///
+IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims,
+                                                    unsigned numSymbols) {
+  if (parseToken(Token::l_paren,
+                 "expected '(' at start of integer set constraint list"))
+    return IntegerSet();
+
+  // constraints[i] is an equality iff isEqs[i] is true.
+  SmallVector<AffineExpr, 4> constraints;
+  SmallVector<bool, 4> isEqs;
+  auto parseElt = [&]() -> ParseResult {
+    bool isEq = false;
+    AffineExpr elt = parseAffineConstraint(&isEq);
+    if (!elt)
+      return failure();
+    constraints.push_back(elt);
+    isEqs.push_back(isEq);
+    return success();
+  };
+
+  // Parse a list of affine constraints (comma-separated).
+  if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true))
+    return IntegerSet();
+
+  // If no constraints were parsed, then treat this as a degenerate 'true' case.
+  if (constraints.empty()) {
+    /* 0 == 0 */
+    auto zero = getAffineConstantExpr(0, getContext());
+    return IntegerSet::get(numDims, numSymbols, zero, true);
+  }
+
+  // Parsed a valid integer set.
+  return IntegerSet::get(numDims, numSymbols, constraints, isEqs);
+}
+
+//===----------------------------------------------------------------------===//
+// Parser
+//===----------------------------------------------------------------------===//
+
+/// Parse an ambiguous reference to either an affine map or an integer set.
+ParseResult Parser::parseAffineMapOrIntegerSetReference(AffineMap &map,
+                                                        IntegerSet &set) {
+  return AffineParser(state).parseAffineMapOrIntegerSetInline(map, set);
+}
+
+/// Parse an affine map reference; an integer set in its place is an error
+/// reported at the start of the reference.
+ParseResult Parser::parseAffineMapReference(AffineMap &map) {
+  llvm::SMLoc curLoc = getToken().getLoc();
+  IntegerSet set;
+  if (parseAffineMapOrIntegerSetReference(map, set))
+    return failure();
+  if (set)
+    return emitError(curLoc, "expected AffineMap, but got IntegerSet");
+  return success();
+}
+
+/// Parse an integer set reference; an affine map in its place is an error
+/// reported at the start of the reference.
+ParseResult Parser::parseIntegerSetReference(IntegerSet &set) {
+  llvm::SMLoc curLoc = getToken().getLoc();
+  AffineMap map;
+  if (parseAffineMapOrIntegerSetReference(map, set))
+    return failure();
+  if (map)
+    return emitError(curLoc, "expected IntegerSet, but got AffineMap");
+  return success();
+}
+
+/// Parse an AffineMap of SSA ids. The callback 'parseElement' is used to
+/// parse SSA value uses encountered while parsing affine expressions; its
+/// argument is true when the use is a symbol.
+ParseResult
+Parser::parseAffineMapOfSSAIds(AffineMap &map,
+                               function_ref<ParseResult(bool)> parseElement,
+                               OpAsmParser::Delimiter delimiter) {
+  return AffineParser(state, /*allowParsingSSAIds=*/true, parseElement)
+      .parseAffineMapOfSSAIds(map, delimiter);
+}
diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -0,0 +1,914 @@
+//===- AttributeParser.cpp - MLIR Attribute Parser Implementation ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the parser for the MLIR Attributes.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Parser.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/StandardTypes.h"
+#include "llvm/ADT/StringExtras.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+/// Parse an arbitrary attribute.
+///
+///  attribute-value ::= `unit`
+///                    | bool-literal
+///                    | integer-literal (`:` (index-type | integer-type))?
+///                    | float-literal (`:` float-type)?
+///                    | string-literal (`:` type)?
+///                    | type
+///                    | `[` (attribute-value (`,` attribute-value)*)? `]`
+///                    | `{` (attribute-entry (`,` attribute-entry)*)? `}`
+///                    | symbol-ref-id (`::` symbol-ref-id)*
+///                    | `affine_map` `<` affine-map `>`
+///                    | `affine_set` `<` integer-set `>`
+///                    | `dense` `<` attribute-value `>` `:`
+///                      (tensor-type | vector-type)
+///                    | `sparse` `<` attribute-value `,` attribute-value `>`
+///                      `:` (tensor-type | vector-type)
+///                    | `opaque` `<` dialect-namespace  `,` hex-string-literal
+///                      `>` `:` (tensor-type | vector-type)
+///                    | extended-attribute
+///
+/// If `type` is non-null, literal attributes use it instead of parsing an
+/// optional trailing `: type`, and it is forwarded to the dense, opaque,
+/// sparse, and extended attribute sub-parsers.
+///
+Attribute Parser::parseAttribute(Type type) {
+  switch (getToken().getKind()) {
+  // Parse an AffineMap or IntegerSet attribute.
+  case Token::kw_affine_map: {
+    consumeToken(Token::kw_affine_map);
+
+    AffineMap map;
+    if (parseToken(Token::less, "expected '<' in affine map") ||
+        parseAffineMapReference(map) ||
+        parseToken(Token::greater, "expected '>' in affine map"))
+      return Attribute();
+    return AffineMapAttr::get(map);
+  }
+  case Token::kw_affine_set: {
+    consumeToken(Token::kw_affine_set);
+
+    IntegerSet set;
+    if (parseToken(Token::less, "expected '<' in integer set") ||
+        parseIntegerSetReference(set) ||
+        parseToken(Token::greater, "expected '>' in integer set"))
+      return Attribute();
+    return IntegerSetAttr::get(set);
+  }
+
+  // Parse an array attribute.
+  case Token::l_square: {
+    consumeToken(Token::l_square);
+
+    SmallVector<Attribute, 4> elements;
+    auto parseElt = [&]() -> ParseResult {
+      elements.push_back(parseAttribute());
+      return elements.back() ? success() : failure();
+    };
+
+    if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
+      return nullptr;
+    return builder.getArrayAttr(elements);
+  }
+
+  // Parse a boolean attribute.
+  case Token::kw_false:
+    consumeToken(Token::kw_false);
+    return builder.getBoolAttr(false);
+  case Token::kw_true:
+    consumeToken(Token::kw_true);
+    return builder.getBoolAttr(true);
+
+  // Parse a dense elements attribute.
+  case Token::kw_dense:
+    return parseDenseElementsAttr(type);
+
+  // Parse a dictionary attribute.
+  case Token::l_brace: {
+    NamedAttrList elements;
+    if (parseAttributeDict(elements))
+      return nullptr;
+    return elements.getDictionary(getContext());
+  }
+
+  // Parse an extended attribute, i.e. alias or dialect attribute.
+  case Token::hash_identifier:
+    return parseExtendedAttr(type);
+
+  // Parse floating point and integer attributes.
+  case Token::floatliteral:
+    return parseFloatAttr(type, /*isNegative=*/false);
+  case Token::integer:
+    return parseDecOrHexAttr(type, /*isNegative=*/false);
+  case Token::minus: {
+    consumeToken(Token::minus);
+    if (getToken().is(Token::integer))
+      return parseDecOrHexAttr(type, /*isNegative=*/true);
+    if (getToken().is(Token::floatliteral))
+      return parseFloatAttr(type, /*isNegative=*/true);
+
+    return (emitError("expected constant integer or floating point value"),
+            nullptr);
+  }
+
+  // Parse a location attribute.
+  case Token::kw_loc: {
+    LocationAttr attr;
+    return failed(parseLocation(attr)) ? Attribute() : attr;
+  }
+
+  // Parse an opaque elements attribute.
+  case Token::kw_opaque:
+    return parseOpaqueElementsAttr(type);
+
+  // Parse a sparse elements attribute.
+  case Token::kw_sparse:
+    return parseSparseElementsAttr(type);
+
+  // Parse a string attribute.
+  case Token::string: {
+    auto val = getToken().getStringValue();
+    consumeToken(Token::string);
+    // Parse the optional trailing colon type if one wasn't explicitly provided.
+    if (!type && consumeIf(Token::colon) && !(type = parseType()))
+      return Attribute();
+
+    return type ? StringAttr::get(val, type)
+                : StringAttr::get(val, getContext());
+  }
+
+  // Parse a symbol reference attribute.
+  case Token::at_identifier: {
+    std::string nameStr = getToken().getSymbolReference();
+    consumeToken(Token::at_identifier);
+
+    // Parse any nested references.
+    std::vector<FlatSymbolRefAttr> nestedRefs;
+    while (getToken().is(Token::colon)) {
+      // Check for the '::' prefix.
+      const char *curPointer = getToken().getLoc().getPointer();
+      consumeToken(Token::colon);
+      if (!consumeIf(Token::colon)) {
+        // Not a '::' separator: rewind the lexer to the first ':' (presumably
+        // leaving it as the current token for the caller) and stop.
+        state.lex.resetPointer(curPointer);
+        consumeToken();
+        break;
+      }
+      // Parse the reference itself.
+      auto curLoc = getToken().getLoc();
+      if (getToken().isNot(Token::at_identifier)) {
+        emitError(curLoc, "expected nested symbol reference identifier");
+        return Attribute();
+      }
+
+      // A distinct name keeps the root symbol's `nameStr` from being shadowed.
+      std::string nestedNameStr = getToken().getSymbolReference();
+      consumeToken(Token::at_identifier);
+      nestedRefs.push_back(SymbolRefAttr::get(nestedNameStr, getContext()));
+    }
+
+    return builder.getSymbolRefAttr(nameStr, nestedRefs);
+  }
+
+  // Parse a 'unit' attribute.
+  case Token::kw_unit:
+    consumeToken(Token::kw_unit);
+    return builder.getUnitAttr();
+
+  default:
+    // Parse a type attribute. The parsed type gets its own name so that it
+    // does not shadow the `type` parameter.
+    if (Type parsedType = parseType())
+      return TypeAttr::get(parsedType);
+    return nullptr;
+  }
+}
+
+/// Attribute dictionary.
+///
+///   attribute-dict ::= `{` `}`
+///                    | `{` attribute-entry (`,` attribute-entry)* `}`
+///   attribute-entry ::= (bare-id | string-literal) `=` attribute-value
+///
+/// Keyword and integer-type spellings are also accepted as entry names. An
+/// entry without `= attribute-value` is recorded as a unit attribute, and a
+/// repeated key is an error.
+///
+ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
+  if (parseToken(Token::l_brace, "expected '{' in attribute dictionary"))
+    return failure();
+
+  llvm::SmallDenseSet<Identifier> seenKeys;
+  auto parseElt = [&]() -> ParseResult {
+    // The name of an attribute can either be a bare identifier, or a string.
+    Optional<Identifier> nameId;
+    if (getToken().is(Token::string))
+      nameId = builder.getIdentifier(getToken().getStringValue());
+    else if (getToken().isAny(Token::bare_identifier, Token::inttype) ||
+             getToken().isKeyword())
+      nameId = builder.getIdentifier(getTokenSpelling());
+    else
+      return emitError("expected attribute name");
+    if (!seenKeys.insert(*nameId).second)
+      return emitError("duplicate key in dictionary attribute");
+    consumeToken();
+
+    // Try to parse the '=' for the attribute value.
+    if (!consumeIf(Token::equal)) {
+      // If there is no '=', we treat this as a unit attribute.
+      attributes.push_back({*nameId, builder.getUnitAttr()});
+      return success();
+    }
+
+    auto attr = parseAttribute();
+    if (!attr)
+      return failure();
+    attributes.push_back({*nameId, attr});
+    return success();
+  };
+
+  if (parseCommaSeparatedListUntil(Token::r_brace, parseElt))
+    return failure();
+
+  return success();
+}
+
+/// Parse a float attribute. The current token is a float literal; a leading
+/// minus sign, if any, has already been consumed and is passed as
+/// `isNegative`. Without an explicit `type`, an optional `: type` suffix is
+/// parsed and f64 is the default.
+Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
+  auto val = getToken().getFloatingPointValue();
+  if (!val.hasValue())
+    return (emitError("floating point value too large for attribute"), nullptr);
+  consumeToken(Token::floatliteral);
+  if (!type) {
+    // Default to F64 when no type is specified.
+    if (!consumeIf(Token::colon))
+      type = builder.getF64Type();
+    else if (!(type = parseType()))
+      return nullptr;
+  }
+  if (!type.isa<FloatType>())
+    return (emitError("floating point value not valid for specified type"),
+            nullptr);
+  return FloatAttr::get(type, isNegative ? -val.getValue() : val.getValue());
+}
+
+/// Construct a float value bitwise equivalent to the integer literal `value`.
+/// Emits an error through `p` and returns None if `value` does not fit in the
+/// width of `type`.
+static Optional<APFloat> buildHexadecimalFloatLiteral(Parser *p, FloatType type,
+                                                      uint64_t value) {
+  // FIXME: bfloat is currently stored as a double internally because it doesn't
+  // have valid APFloat semantics.
+  if (type.isF64() || type.isBF16())
+    return APFloat(type.getFloatSemantics(), APInt(/*numBits=*/64, value));
+
+  APInt apInt(type.getWidth(), value);
+  if (apInt != value) {
+    p->emitError("hexadecimal float constant out of range for type");
+    return llvm::None;
+  }
+  return APFloat(type.getFloatSemantics(), apInt);
+}
+
+/// Construct an APInt from a parsed value, a known attribute type and sign.
+/// Returns None if `spelling` is not a valid integer or the value does not fit
+/// `type` with the requested sign.
+static Optional<APInt> buildAttributeAPInt(Type type, bool isNegative,
+                                           StringRef spelling) {
+  // Parse the integer value into an APInt that is big enough to hold the value.
+  APInt result;
+  bool isHex = spelling.size() > 1 && spelling[1] == 'x';
+  if (spelling.getAsInteger(isHex ? 0 : 10, result))
+    return llvm::None;
+
+  // Extend or truncate the bitwidth to the right size.
+  unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth
+                                  : type.getIntOrFloatBitWidth();
+  if (width > result.getBitWidth()) {
+    result = result.zext(width);
+  } else if (width < result.getBitWidth()) {
+    // The parser can return an unnecessarily wide result with leading zeros.
+    // This isn't a problem, but truncating off bits is bad.
+    if (result.countLeadingZeros() < result.getBitWidth() - width)
+      return llvm::None;
+
+    result = result.trunc(width);
+  }
+
+  if (isNegative) {
+    // The value is negative, we have an overflow if the sign bit is not set
+    // in the negated apInt.
+    result.negate();
+    if (!result.isSignBitSet())
+      return llvm::None;
+  } else if ((type.isSignedInteger() || type.isIndex()) &&
+             result.isSignBitSet()) {
+    // The value is a positive signed integer or index,
+    // we have an overflow if the sign bit is set.
+    return llvm::None;
+  }
+
+  return result;
+}
+
+/// Parse a decimal or a hexadecimal literal, which can be either an integer
+/// or a float attribute. A float attribute may only be initialized from a
+/// non-negative hexadecimal literal, which spells its bit pattern.
+Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
+  // Capture the literal's spelling and location before consuming it, so that
+  // diagnostics below point at the literal rather than at a trailing type.
+  StringRef spelling = getToken().getSpelling();
+  auto loc = state.curToken.getLoc();
+  bool isHex = spelling.size() > 1 && spelling[1] == 'x';
+
+  consumeToken(Token::integer);
+  if (!type) {
+    // Default to i64 if no type is specified.
+    if (!consumeIf(Token::colon))
+      type = builder.getIntegerType(64);
+    else if (!(type = parseType()))
+      return nullptr;
+  }
+
+  if (auto floatType = type.dyn_cast<FloatType>()) {
+    // Check the radix first: a decimal literal is never valid here, so it
+    // must not be diagnosed as a "hexadecimal" literal with a minus sign.
+    if (!isHex) {
+      emitError(loc, "unexpected decimal integer literal for a float attribute")
+              .attachNote()
+          << "add a trailing dot to make the literal a float";
+      return nullptr;
+    }
+    if (isNegative) {
+      emitError(loc,
+                "hexadecimal float literal should not have a leading minus");
+      return nullptr;
+    }
+
+    auto val = Token::getUInt64IntegerValue(spelling);
+    if (!val.hasValue()) {
+      emitError(loc, "integer constant out of range for attribute");
+      return nullptr;
+    }
+
+    // Construct a float attribute bitwise equivalent to the integer literal.
+    Optional<APFloat> apVal =
+        buildHexadecimalFloatLiteral(this, floatType, *val);
+    return apVal ? FloatAttr::get(floatType, *apVal) : Attribute();
+  }
+
+  if (!type.isa<IntegerType>() && !type.isa<IndexType>()) {
+    emitError(loc, "integer literal not valid for specified type");
+    return nullptr;
+  }
+
+  if (isNegative && type.isUnsignedInteger()) {
+    emitError(loc,
+              "negative integer literal not valid for unsigned integer type");
+    return nullptr;
+  }
+
+  Optional<APInt> apInt = buildAttributeAPInt(type, isNegative, spelling);
+  if (!apInt) {
+    emitError(loc, "integer constant out of range for attribute");
+    return nullptr;
+  }
+  return builder.getIntegerAttr(type, *apInt);
+}
+
+//===----------------------------------------------------------------------===//
+// TensorLiteralParser
+//===----------------------------------------------------------------------===//
+
+/// Parse element values stored within a hex string. On success, the values are
+/// stored into 'result'.
+///
+/// 'tok' is expected to be a string token whose value is "0x" followed only
+/// by hex digits; the decoded bytes are written to 'result'. Emits a
+/// diagnostic at the token and fails otherwise.
+static ParseResult parseElementAttrHexValues(Parser &parser, Token tok,
+                                             std::string &result) {
+  std::string val = tok.getStringValue();
+  if (val.size() < 2 || val[0] != '0' || val[1] != 'x')
+    return parser.emitError(tok.getLoc(),
+                            "elements hex string should start with '0x'");
+
+  StringRef hexValues = StringRef(val).drop_front(2);
+  if (!llvm::all_of(hexValues, llvm::isHexDigit))
+    return parser.emitError(tok.getLoc(),
+                            "elements hex string only contains hex digits");
+
+  result = llvm::fromHex(hexValues);
+  return success();
+}
+
+namespace {
+
+/// This class implements a parser for TensorLiterals. A tensor literal is
+/// either a single element (e.g., 5) or a multi-dimensional list of elements
+/// (e.g., [[5, 5]]).
+class TensorLiteralParser {
+public:
+  explicit TensorLiteralParser(Parser &p) : p(p) {}
+
+  /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
+  /// may also parse a tensor literal that is stored as a hex string.
+  ParseResult parse(bool allowHex);
+
+  /// Build a dense attribute instance with the parsed elements and the given
+  /// shaped type. Returns null (after emitting a diagnostic) on failure.
+  DenseElementsAttr getAttr(llvm::SMLoc loc, ShapedType type);
+
+  /// The shape inferred from a parsed list. Empty when a single (splat)
+  /// element or a hex string was parsed.
+  ArrayRef<int64_t> getShape() const { return shape; }
+
+private:
+  /// Get the parsed elements for an integer attribute.
+  ParseResult getIntAttrElements(llvm::SMLoc loc, Type eltTy,
+                                 std::vector<APInt> &intValues);
+
+  /// Get the parsed elements for a float attribute.
+  ParseResult getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy,
+                                   std::vector<APFloat> &floatValues);
+
+  /// Build a Dense String attribute for the given type.
+  DenseElementsAttr getStringAttr(llvm::SMLoc loc, ShapedType type, Type eltTy);
+
+  /// Build a Dense attribute with hex data for the given type.
+  DenseElementsAttr getHexAttr(llvm::SMLoc loc, ShapedType type);
+
+  /// Parse a single element, returning failure if it isn't a valid element
+  /// literal. For example:
+  /// parseElement(1) -> Success, 1
+  /// parseElement([1]) -> Failure
+  ParseResult parseElement();
+
+  /// Parse a list of either lists or elements, returning the dimensions of the
+  /// parsed sub-tensors in dims. For example:
+  ///   parseList([1, 2, 3]) -> Success, [3]
+  ///   parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
+  ///   parseList([[1, 2], 3]) -> Failure
+  ///   parseList([[1, [2, 3]], [4, [5]]]) -> Failure
+  ParseResult parseList(SmallVectorImpl<int64_t> &dims);
+
+  Parser &p;
+
+  /// The shape inferred from the parsed elements.
+  SmallVector<int64_t, 4> shape;
+
+  /// Storage used when parsing elements, this is a pair of <is_negated, token>.
+  /// A complex element occupies two consecutive entries (real, imaginary).
+  std::vector<std::pair<bool, Token>> storage;
+
+  /// Storage used when parsing elements that were stored as hex values.
+  Optional<Token> hexStorage;
+};
+} // end anonymous namespace
+
+/// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
+/// may also parse a tensor literal that is stored as a hex string.
+ParseResult TensorLiteralParser::parse(bool allowHex) {
+  // If hex is allowed, check for a string literal.
+  if (allowHex && p.getToken().is(Token::string)) {
+    hexStorage = p.getToken();
+    p.consumeToken(Token::string);
+    return success();
+  }
+  // Otherwise, parse a list or an individual element.
+  if (p.getToken().is(Token::l_square))
+    return parseList(shape);
+  return parseElement();
+}
+
+/// Build a dense attribute instance with the parsed elements and the given
+/// shaped type. Emits a diagnostic at 'loc' (or at the offending element) and
+/// returns null on failure.
+DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
+                                               ShapedType type) {
+  Type eltType = type.getElementType();
+
+  // Check to see if we parse the literal from a hex string.
+  if (hexStorage.hasValue() &&
+      (eltType.isIntOrFloat() || eltType.isa<ComplexType>()))
+    return getHexAttr(loc, type);
+
+  // Check that the parsed storage size has the same number of elements to the
+  // type, or is a known splat.
+  if (!shape.empty() && getShape() != type.getShape()) {
+    p.emitError(loc) << "inferred shape of elements literal ([" << getShape()
+                     << "]) does not match type ([" << type.getShape() << "])";
+    return nullptr;
+  }
+
+  // Handle complex types in the specific element type cases below.
+  bool isComplex = false;
+  if (ComplexType complexTy = eltType.dyn_cast<ComplexType>()) {
+    eltType = complexTy.getElementType();
+    isComplex = true;
+  }
+
+  // Each parsed element occupies one storage slot, except complex elements,
+  // which are parsed as a '(' real ',' imag ')' pair and occupy two. Verify the
+  // slot count before the storage is reinterpreted below. Otherwise, e.g., a
+  // list of plain integers for a complex type would be silently misread (even
+  // as a splat), and a mismatched count would reach DenseElementsAttr::get.
+  // Hex-encoded data lives in 'hexStorage' rather than 'storage'.
+  if (!hexStorage.hasValue()) {
+    size_t slotsPerElement = isComplex ? 2 : 1;
+    size_t expectedSlots =
+        shape.empty()
+            ? slotsPerElement
+            : static_cast<size_t>(type.getNumElements()) * slotsPerElement;
+    if (storage.size() != expectedSlots) {
+      p.emitError(loc) << "parsed " << storage.size()
+                       << " elements, but type (" << type << ") expected "
+                       << expectedSlots << " elements";
+      return nullptr;
+    }
+  }
+
+  // Handle integer and index types.
+  if (eltType.isIntOrIndex()) {
+    std::vector<APInt> intValues;
+    if (failed(getIntAttrElements(loc, eltType, intValues)))
+      return nullptr;
+    if (isComplex) {
+      // If this is a complex, treat the parsed values as complex values.
+      auto complexData = llvm::makeArrayRef(
+          reinterpret_cast<std::complex<APInt> *>(intValues.data()),
+          intValues.size() / 2);
+      return DenseElementsAttr::get(type, complexData);
+    }
+    return DenseElementsAttr::get(type, intValues);
+  }
+  // Handle floating point types.
+  if (FloatType floatTy = eltType.dyn_cast<FloatType>()) {
+    std::vector<APFloat> floatValues;
+    if (failed(getFloatAttrElements(loc, floatTy, floatValues)))
+      return nullptr;
+    if (isComplex) {
+      // If this is a complex, treat the parsed values as complex values.
+      auto complexData = llvm::makeArrayRef(
+          reinterpret_cast<std::complex<APFloat> *>(floatValues.data()),
+          floatValues.size() / 2);
+      return DenseElementsAttr::get(type, complexData);
+    }
+    return DenseElementsAttr::get(type, floatValues);
+  }
+
+  // Other types are assumed to be string representations.
+  return getStringAttr(loc, type, type.getElementType());
+}
+
+/// Build a Dense Integer attribute for the given type.
+///
+/// Converts each parsed (sign, token) pair in 'storage' into an APInt sized
+/// for 'eltTy' and appends it to 'intValues'. Emits a diagnostic at the
+/// offending token and returns failure if an element is not a valid integer
+/// literal for 'eltTy'.
+ParseResult
+TensorLiteralParser::getIntAttrElements(llvm::SMLoc loc, Type eltTy,
+                                        std::vector<APInt> &intValues) {
+  intValues.reserve(storage.size());
+  bool isUintType = eltTy.isUnsignedInteger();
+  for (const auto &signAndToken : storage) {
+    bool isNegative = signAndToken.first;
+    const Token &token = signAndToken.second;
+    auto tokenLoc = token.getLoc();
+
+    if (isNegative && isUintType) {
+      return p.emitError(tokenLoc)
+             << "expected unsigned integer elements, but parsed negative value";
+    }
+
+    // Check to see if floating point values were parsed.
+    if (token.is(Token::floatliteral)) {
+      return p.emitError(tokenLoc)
+             << "expected integer elements, but parsed floating-point";
+    }
+
+    // Any other non-integer token (e.g. a string element inside a list) used
+    // to be caught only by an assertion; report it to the user instead.
+    if (!token.isAny(Token::integer, Token::kw_true, Token::kw_false)) {
+      return p.emitError(tokenLoc)
+             << "expected integer elements, but parsed " << token.getSpelling();
+    }
+
+    if (token.isAny(Token::kw_true, Token::kw_false)) {
+      if (!eltTy.isInteger(1)) {
+        return p.emitError(tokenLoc)
+               << "expected i1 type for 'true' or 'false' values";
+      }
+      APInt apInt(1, token.is(Token::kw_true), /*isSigned=*/false);
+      intValues.push_back(apInt);
+      continue;
+    }
+
+    // Create APInt values for each element with the correct bitwidth.
+    Optional<APInt> apInt =
+        buildAttributeAPInt(eltTy, isNegative, token.getSpelling());
+    if (!apInt)
+      return p.emitError(tokenLoc, "integer constant out of range for type");
+    intValues.push_back(*apInt);
+  }
+  return success();
+}
+
+/// Convert each parsed (sign, token) pair in 'storage' into an APFloat for
+/// 'eltTy' and append it to 'floatValues'. Hexadecimal integer literals are
+/// reinterpreted bitwise; every other element must be a float literal.
+/// Diagnostics are reported at the offending element.
+ParseResult
+TensorLiteralParser::getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy,
+                                          std::vector<APFloat> &floatValues) {
+  floatValues.reserve(storage.size());
+  for (const auto &signAndToken : storage) {
+    bool isNegative = signAndToken.first;
+    const Token &token = signAndToken.second;
+    auto tokenLoc = token.getLoc();
+
+    // Handle hexadecimal float literals.
+    if (token.is(Token::integer) && token.getSpelling().startswith("0x")) {
+      if (isNegative) {
+        return p.emitError(tokenLoc)
+               << "hexadecimal float literal should not have a leading minus";
+      }
+      auto val = token.getUInt64IntegerValue();
+      if (!val.hasValue()) {
+        return p.emitError(
+            tokenLoc, "hexadecimal float constant out of range for attribute");
+      }
+      Optional<APFloat> apVal = buildHexadecimalFloatLiteral(&p, eltTy, *val);
+      if (!apVal)
+        return failure();
+      floatValues.push_back(*apVal);
+      continue;
+    }
+
+    // Check to see if any decimal integers or booleans were parsed.
+    if (!token.is(Token::floatliteral))
+      return p.emitError(tokenLoc)
+             << "expected floating-point elements, but parsed integer";
+
+    // Build the float values from tokens.
+    auto val = token.getFloatingPointValue();
+    if (!val.hasValue())
+      return p.emitError(tokenLoc,
+                         "floating point value too large for attribute");
+
+    // Treat BF16 as double because it is not supported in LLVM's APFloat.
+    APFloat apVal(isNegative ? -*val : *val);
+    if (!eltTy.isBF16() && !eltTy.isF64()) {
+      bool unused;
+      apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven,
+                    &unused);
+    }
+    floatValues.push_back(apVal);
+  }
+  return success();
+}
+
+/// Build a Dense String attribute for the given type. If the literal was a
+/// single hex string, its raw value is used as the only string value;
+/// otherwise every parsed element must be a string literal.
+DenseElementsAttr TensorLiteralParser::getStringAttr(llvm::SMLoc loc,
+                                                     ShapedType type,
+                                                     Type eltTy) {
+  if (hexStorage.hasValue()) {
+    auto stringValue = hexStorage.getValue().getStringValue();
+    return DenseStringElementsAttr::get(type, {stringValue});
+  }
+
+  std::vector<std::string> stringValues;
+  std::vector<StringRef> stringRefValues;
+  // Reserving up front keeps 'stringValues' from reallocating, so the
+  // StringRefs taken from its elements stay valid.
+  stringValues.reserve(storage.size());
+  stringRefValues.reserve(storage.size());
+
+  for (const auto &signAndToken : storage) {
+    const Token &token = signAndToken.second;
+    // getStringValue() is only valid on string tokens; report anything else
+    // (e.g. an integer element for a non-numeric element type).
+    if (!token.is(Token::string)) {
+      p.emitError(token.getLoc())
+          << "expected string elements, but parsed " << token.getSpelling();
+      return nullptr;
+    }
+    stringValues.push_back(token.getStringValue());
+    stringRefValues.push_back(stringValues.back());
+  }
+
+  return DenseStringElementsAttr::get(type, stringRefValues);
+}
+
+/// Build a Dense attribute with hex data for the given type.
+DenseElementsAttr TensorLiteralParser::getHexAttr(llvm::SMLoc loc,
+                                                  ShapedType type) {
+  // Only integer, index, float, and complex element types have a raw-buffer
+  // encoding that hex data can be decoded into.
+  Type eltType = type.getElementType();
+  bool hasRawEncoding =
+      eltType.isIntOrIndexOrFloat() || eltType.isa<ComplexType>();
+  if (!hasRawEncoding) {
+    p.emitError(loc)
+        << "expected floating-point, integer, or complex element type, got "
+        << eltType;
+    return nullptr;
+  }
+
+  // Decode the "0x..." string into raw bytes.
+  std::string bytes;
+  if (failed(parseElementAttrHexValues(p, *hexStorage, bytes)))
+    return nullptr;
+
+  // The byte count must match either the full tensor or a single splat value.
+  ArrayRef<char> buffer(bytes.data(), bytes.size());
+  bool isSplat = false;
+  if (DenseElementsAttr::isValidRawBuffer(type, buffer, isSplat))
+    return DenseElementsAttr::getFromRawBuffer(type, buffer, isSplat);
+
+  p.emitError(loc) << "elements hex data size is invalid for provided type: "
+                   << type;
+  return nullptr;
+}
+
+/// Parse one non-list element into 'storage': a boolean, integer, float, or
+/// string literal; a minus-prefixed integer or float literal (recorded as
+/// negated); or a complex '(' element ',' element ')' pair, which is recorded
+/// as two consecutive entries.
+ParseResult TensorLiteralParser::parseElement() {
+  switch (p.getToken().getKind()) {
+  // Literals recorded as-is.
+  case Token::kw_true:
+  case Token::kw_false:
+  case Token::floatliteral:
+  case Token::integer:
+  case Token::string:
+    storage.emplace_back(/*isNegative=*/false, p.getToken());
+    p.consumeToken();
+    return success();
+
+  // A leading minus may only apply to a numeric literal.
+  case Token::minus:
+    p.consumeToken(Token::minus);
+    if (!p.getToken().isAny(Token::floatliteral, Token::integer))
+      return p.emitError("expected integer or floating point literal");
+    storage.emplace_back(/*isNegative=*/true, p.getToken());
+    p.consumeToken();
+    return success();
+
+  // Complex element: '(' real ',' imaginary ')'.
+  case Token::l_paren:
+    p.consumeToken(Token::l_paren);
+    if (parseElement() ||
+        p.parseToken(Token::comma, "expected ',' between complex elements") ||
+        parseElement() ||
+        p.parseToken(Token::r_paren, "expected ')' after complex elements"))
+      return failure();
+    return success();
+
+  default:
+    return p.emitError("expected element literal of primitive type");
+  }
+}
+
+/// Parse a list of either lists or elements, returning the dimensions of the
+/// parsed sub-tensors in dims. For example:
+///   parseList([1, 2, 3]) -> Success, [3]
+///   parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
+///   parseList([[1, 2], 3]) -> Failure
+///   parseList([[1, [2, 3]], [4, [5]]]) -> Failure
+ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
+  p.consumeToken(Token::l_square);
+
+  // Dimensions of the first sub-element; every later one must match them.
+  SmallVector<int64_t, 4> elementDims;
+  unsigned numElements = 0;
+
+  auto parseListElement = [&]() -> ParseResult {
+    SmallVector<int64_t, 4> curDims;
+    bool isNestedList = p.getToken().getKind() == Token::l_square;
+    if (isNestedList ? parseList(curDims) : parseElement())
+      return failure();
+
+    ++numElements;
+    if (numElements == 1) {
+      elementDims = curDims;
+      return success();
+    }
+    if (curDims == elementDims)
+      return success();
+    return p.emitError("tensor literal is invalid; ranks are not consistent "
+                       "between elements");
+  };
+
+  if (p.parseCommaSeparatedListUntil(Token::r_square, parseListElement))
+    return failure();
+
+  // This list's shape is its own length followed by its sub-elements' shape.
+  dims.clear();
+  dims.push_back(numElements);
+  dims.append(elementDims.begin(), elementDims.end());
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ElementsAttr Parser
+//===----------------------------------------------------------------------===//
+
+/// Parse a dense elements attribute.
+///
+///   dense-elements-attribute ::= `dense` `<` tensor-literal `>`
+///                                (`:` elements-literal-type)?
+///
+/// The trailing type is only parsed when 'attrType' is null.
+Attribute Parser::parseDenseElementsAttr(Type attrType) {
+  consumeToken(Token::kw_dense);
+  if (parseToken(Token::less, "expected '<' after 'dense'"))
+    return nullptr;
+
+  // The body may be a nested list, a single splat element, or a hex string.
+  TensorLiteralParser literalParser(*this);
+  if (literalParser.parse(/*allowHex=*/true) ||
+      parseToken(Token::greater, "expected '>'"))
+    return nullptr;
+
+  // Mismatches between the literal and its type are reported at the type.
+  auto typeLoc = getToken().getLoc();
+  if (ShapedType type = parseElementsLiteralType(attrType))
+    return literalParser.getAttr(typeLoc, type);
+  return nullptr;
+}
+
+/// Parse an opaque elements attribute:
+///
+///   `opaque` `<` dialect-namespace-string `,` hex-string `>`
+///   (`:` elements-literal-type)?
+///
+/// The named dialect must be registered with the context.
+Attribute Parser::parseOpaqueElementsAttr(Type attrType) {
+  consumeToken(Token::kw_opaque);
+  if (parseToken(Token::less, "expected '<' after 'opaque'"))
+    return nullptr;
+
+  // Resolve the owning dialect from its namespace string.
+  if (getToken().isNot(Token::string))
+    return (emitError("expected dialect namespace"), nullptr);
+  std::string dialectNamespace = getToken().getStringValue();
+  auto *dialect =
+      builder.getContext()->getRegisteredDialect(dialectNamespace);
+  // TODO(shpeisman): Allow for having an unknown dialect on an opaque
+  // attribute. Otherwise, it can't be roundtripped without having the dialect
+  // registered.
+  if (!dialect)
+    return (emitError("no registered dialect with namespace '" +
+                      dialectNamespace + "'"),
+            nullptr);
+  consumeToken(Token::string);
+
+  if (parseToken(Token::comma, "expected ','"))
+    return nullptr;
+
+  // Keep the hex token; it is decoded only once the type has been parsed.
+  Token hexTok = getToken();
+  if (parseToken(Token::string, "elements hex string should start with '0x'") ||
+      parseToken(Token::greater, "expected '>'"))
+    return nullptr;
+
+  ShapedType type = parseElementsLiteralType(attrType);
+  if (!type)
+    return nullptr;
+
+  std::string rawBytes;
+  if (failed(parseElementAttrHexValues(*this, hexTok, rawBytes)))
+    return nullptr;
+  return OpaqueElementsAttr::get(dialect, type, rawBytes);
+}
+
+/// Shaped type for elements attribute.
+///
+///   elements-literal-type ::= vector-type | ranked-tensor-type
+///
+/// This method also checks the type has static shape.
+ShapedType Parser::parseElementsLiteralType(Type type) {
+  // If the user didn't provide a type, parse the colon type for the literal.
+  if (!type) {
+    if (parseToken(Token::colon, "expected ':'"))
+      return nullptr;
+    if (!(type = parseType()))
+      return nullptr;
+  }
+
+  if (!type.isa<RankedTensorType>() && !type.isa<VectorType>()) {
+    emitError("elements literal must be a ranked tensor or vector type");
+    return nullptr;
+  }
+
+  auto sType = type.cast<ShapedType>();
+  if (!sType.hasStaticShape())
+    return (emitError("elements literal type must have static shape"), nullptr);
+
+  return sType;
+}
+
+/// Parse a sparse elements attribute.
+///
+///   sparse-elements-attribute ::= `sparse` `<` indices-literal `,`
+///                                 values-literal `>`
+///                                 (`:` elements-literal-type)?
+///
+/// The trailing type is only parsed when 'attrType' is null. Returns null
+/// after emitting a diagnostic on failure.
+Attribute Parser::parseSparseElementsAttr(Type attrType) {
+  consumeToken(Token::kw_sparse);
+  if (parseToken(Token::less, "Expected '<' after 'sparse'"))
+    return nullptr;
+
+  // Parse the indices. We don't allow hex values here as we may need to use
+  // the inferred shape.
+  auto indicesLoc = getToken().getLoc();
+  TensorLiteralParser indiceParser(*this);
+  if (indiceParser.parse(/*allowHex=*/false))
+    return nullptr;
+
+  if (parseToken(Token::comma, "expected ','"))
+    return nullptr;
+
+  // Parse the values.
+  auto valuesLoc = getToken().getLoc();
+  TensorLiteralParser valuesParser(*this);
+  if (valuesParser.parse(/*allowHex=*/true))
+    return nullptr;
+
+  if (parseToken(Token::greater, "expected '>'"))
+    return nullptr;
+
+  auto type = parseElementsLiteralType(attrType);
+  if (!type)
+    return nullptr;
+
+  // If the indices are a splat, i.e. the literal parser parsed an element and
+  // not a list, we set the shape explicitly. The indices are represented by a
+  // 2-dimensional shape where the second dimension is the rank of the type.
+  // Given that the parsed indices is a splat, we know that we only have one
+  // indice and thus one for the first dimension.
+  auto indiceEltType = builder.getIntegerType(64);
+  ShapedType indicesType;
+  if (indiceParser.getShape().empty()) {
+    indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType);
+  } else {
+    // Otherwise, set the shape to the one parsed by the literal parser.
+    indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType);
+  }
+  auto indices = indiceParser.getAttr(indicesLoc, indicesType);
+  // getAttr reports its own diagnostic on failure. Bail out instead of handing
+  // a null attribute to SparseElementsAttr::get below.
+  if (!indices)
+    return nullptr;
+
+  // If the values are a splat, set the shape explicitly based on the number of
+  // indices. The number of indices is encoded in the first dimension of the
+  // indice shape type.
+  auto valuesEltType = type.getElementType();
+  ShapedType valuesType =
+      valuesParser.getShape().empty()
+          ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType)
+          : RankedTensorType::get(valuesParser.getShape(), valuesEltType);
+  auto values = valuesParser.getAttr(valuesLoc, valuesType);
+  if (!values)
+    return nullptr;
+
+  // Sanity check.
+  if (valuesType.getRank() != 1)
+    return (emitError("expected 1-d tensor for values"), nullptr);
+
+  auto sameShape = (indicesType.getRank() == 1) ||
+                   (type.getRank() == indicesType.getDimSize(1));
+  auto sameElementNum = indicesType.getDimSize(0) == valuesType.getDimSize(0);
+  if (!sameShape || !sameElementNum) {
+    emitError() << "expected shape ([" << type.getShape()
+                << "]); inferred shape of indices literal (["
+                << indicesType.getShape()
+                << "]); inferred shape of values literal (["
+                << valuesType.getShape() << "])";
+    return nullptr;
+  }
+
+  // Build the sparse elements attribute by the indices and values.
+  return SparseElementsAttr::get(type, indices, values);
+}
diff --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Parser/DialectSymbolParser.cpp
@@ -0,0 +1,612 @@
+//===- DialectSymbolParser.cpp - MLIR Dialect Symbol Parser --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the parser for the dialect symbols, such as extended
+// attributes and types.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Parser.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/StandardTypes.h"
+#include "llvm/Support/SourceMgr.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+using llvm::MemoryBuffer;
+using llvm::SMLoc;
+using llvm::SourceMgr;
+
+namespace {
+/// DialectAsmParser implementation that forwards every request to the main
+/// MLIR Parser, so that dialects can parse the bodies of their custom
+/// attributes and types through the main parsing logic.
+class CustomDialectAsmParser : public DialectAsmParser {
+public:
+  CustomDialectAsmParser(StringRef fullSpec, Parser &parser)
+      : fullSpec(fullSpec), nameLoc(parser.getToken().getLoc()),
+        parser(parser) {}
+  ~CustomDialectAsmParser() override = default;
+
+  /// Report 'message' at 'loc' through the main parser; returns failure.
+  InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override {
+    return parser.emitError(loc, message);
+  }
+
+  /// The main parser's builder, giving access to the MLIRContext and to
+  /// uniqued objects such as types and attributes.
+  Builder &getBuilder() const override { return parser.builder; }
+
+  /// Location of the token currently being looked at. Never fails.
+  llvm::SMLoc getCurrentLocation() override {
+    return parser.getToken().getLoc();
+  }
+
+  /// Location of the token that was current when this parser was created
+  /// (the symbol's name).
+  llvm::SMLoc getNameLoc() const override { return nameLoc; }
+
+  /// Convert a raw source location into an MLIR Location.
+  Location getEncodedSourceLoc(llvm::SMLoc loc) override {
+    return parser.getEncodedSourceLocation(loc);
+  }
+
+  /// The complete text of the symbol being parsed, for dialects that want to
+  /// run a separate parser over it.
+  StringRef getFullSymbolSpec() const override { return fullSpec; }
+
+  /// Parse an optionally negated floating point literal into 'result'.
+  ParseResult parseFloat(double &result) override {
+    bool isNegated = parser.consumeIf(Token::minus);
+    Token tok = parser.getToken();
+
+    // TODO(riverriddle) support hex floating point values.
+    if (tok.isNot(Token::floatliteral))
+      return emitError(getCurrentLocation(), "expected floating point literal");
+
+    auto value = tok.getFloatingPointValue();
+    if (!value.hasValue())
+      return emitError(tok.getLoc(), "floating point value too large");
+    parser.consumeToken(Token::floatliteral);
+    result = isNegated ? -*value : *value;
+    return success();
+  }
+
+  /// Parse an optionally negated integer literal into 'result'. Returns None
+  /// (consuming nothing) when the next token is neither '-' nor an integer.
+  OptionalParseResult parseOptionalInteger(uint64_t &result) override {
+    if (!parser.getToken().isAny(Token::integer, Token::minus))
+      return llvm::None;
+
+    bool isNegated = parser.consumeIf(Token::minus);
+    Token intTok = parser.getToken();
+    if (parser.parseToken(Token::integer, "expected integer value"))
+      return failure();
+
+    auto value = intTok.getUInt64IntegerValue();
+    if (!value)
+      return emitError(intTok.getLoc(), "integer value too large");
+    result = isNegated ? -*value : *value;
+    return success();
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Token Parsing: required punctuation (emit an error when absent)
+  //===--------------------------------------------------------------------===//
+
+  ParseResult parseArrow() override {
+    return expectToken(Token::arrow, "expected '->'");
+  }
+  ParseResult parseLBrace() override {
+    return expectToken(Token::l_brace, "expected '{'");
+  }
+  ParseResult parseRBrace() override {
+    return expectToken(Token::r_brace, "expected '}'");
+  }
+  ParseResult parseColon() override {
+    return expectToken(Token::colon, "expected ':'");
+  }
+  ParseResult parseComma() override {
+    return expectToken(Token::comma, "expected ','");
+  }
+  ParseResult parseEqual() override {
+    return expectToken(Token::equal, "expected '='");
+  }
+  ParseResult parseLess() override {
+    return expectToken(Token::less, "expected '<'");
+  }
+  ParseResult parseGreater() override {
+    return expectToken(Token::greater, "expected '>'");
+  }
+  ParseResult parseLParen() override {
+    return expectToken(Token::l_paren, "expected '('");
+  }
+  ParseResult parseRParen() override {
+    return expectToken(Token::r_paren, "expected ')'");
+  }
+  ParseResult parseLSquare() override {
+    return expectToken(Token::l_square, "expected '['");
+  }
+  ParseResult parseRSquare() override {
+    return expectToken(Token::r_square, "expected ']'");
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Token Parsing: optional punctuation (fail silently when absent)
+  //===--------------------------------------------------------------------===//
+
+  ParseResult parseOptionalArrow() override {
+    return optionalToken(Token::arrow);
+  }
+  ParseResult parseOptionalLBrace() override {
+    return optionalToken(Token::l_brace);
+  }
+  ParseResult parseOptionalRBrace() override {
+    return optionalToken(Token::r_brace);
+  }
+  ParseResult parseOptionalColon() override {
+    return optionalToken(Token::colon);
+  }
+  ParseResult parseOptionalComma() override {
+    return optionalToken(Token::comma);
+  }
+  ParseResult parseOptionalEllipsis() override {
+    return optionalToken(Token::ellipsis);
+  }
+  ParseResult parseOptionalLess() override {
+    return optionalToken(Token::less);
+  }
+  ParseResult parseOptionalGreater() override {
+    return optionalToken(Token::greater);
+  }
+  ParseResult parseOptionalLParen() override {
+    return optionalToken(Token::l_paren);
+  }
+  ParseResult parseOptionalRParen() override {
+    return optionalToken(Token::r_paren);
+  }
+  ParseResult parseOptionalLSquare() override {
+    return optionalToken(Token::l_square);
+  }
+  ParseResult parseOptionalRSquare() override {
+    return optionalToken(Token::r_square);
+  }
+  ParseResult parseOptionalQuestion() override {
+    return optionalToken(Token::question);
+  }
+  ParseResult parseOptionalStar() override {
+    return optionalToken(Token::star);
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Keyword Parsing
+  //===--------------------------------------------------------------------===//
+
+  /// True when the current token is a bare identifier or a language keyword.
+  bool isCurrentTokenAKeyword() const {
+    const Token &tok = parser.getToken();
+    return tok.is(Token::bare_identifier) || tok.isKeyword();
+  }
+
+  /// Consume the current token only if it is a keyword spelled 'keyword'.
+  ParseResult parseOptionalKeyword(StringRef keyword) override {
+    bool matches =
+        isCurrentTokenAKeyword() && parser.getTokenSpelling() == keyword;
+    if (matches)
+      parser.consumeToken();
+    return success(matches);
+  }
+
+  /// Consume the current token if it is a keyword, storing its spelling.
+  ParseResult parseOptionalKeyword(StringRef *keyword) override {
+    if (!isCurrentTokenAKeyword())
+      return failure();
+    *keyword = parser.getTokenSpelling();
+    parser.consumeToken();
+    return success();
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Attribute Parsing
+  //===--------------------------------------------------------------------===//
+
+  /// Parse any attribute (optionally of 'type') into 'result'.
+  ParseResult parseAttribute(Attribute &result, Type type) override {
+    result = parser.parseAttribute(type);
+    return failure(!result);
+  }
+
+  /// Parse an affine map instance into 'map'.
+  ParseResult parseAffineMap(AffineMap &map) override {
+    return parser.parseAffineMapReference(map);
+  }
+
+  /// Parse an integer set instance into 'set'. The name is dictated by the
+  /// overridden DialectAsmParser hook; this parses, it does not print.
+  ParseResult printIntegerSet(IntegerSet &set) override {
+    return parser.parseIntegerSetReference(set);
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Type Parsing
+  //===--------------------------------------------------------------------===//
+
+  /// Parse any type into 'result'.
+  ParseResult parseType(Type &result) override {
+    result = parser.parseType();
+    return failure(!result);
+  }
+
+  /// Parse a ranked dimension list, allowing '?' only if 'allowDynamic'.
+  ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
+                                 bool allowDynamic) override {
+    return parser.parseDimensionListRanked(dimensions, allowDynamic);
+  }
+
+private:
+  /// Consume a token of 'kind', or emit 'message' and fail.
+  ParseResult expectToken(Token::Kind kind, const Twine &message) {
+    return parser.parseToken(kind, message);
+  }
+
+  /// Consume a token of 'kind' if it is next; fail without a diagnostic
+  /// otherwise.
+  ParseResult optionalToken(Token::Kind kind) {
+    return success(parser.consumeIf(kind));
+  }
+
+  /// The full symbol specification.
+  StringRef fullSpec;
+
+  /// The source location of the dialect symbol.
+  SMLoc nameLoc;
+
+  /// The main parser all requests are forwarded to.
+  Parser &parser;
+};
+} // namespace
+
+/// Parse the body of a pretty dialect symbol, which starts and ends with <>'s,
+/// and may be recursive. Return with the 'prettyName' StringRef encompassing
+/// the entire pretty name.
+/// +/// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>' +/// pretty-dialect-sym-contents ::= pretty-dialect-sym-body +/// | '(' pretty-dialect-sym-contents+ ')' +/// | '[' pretty-dialect-sym-contents+ ']' +/// | '{' pretty-dialect-sym-contents+ '}' +/// | '[^[<({>\])}\0]+' +/// +ParseResult Parser::parsePrettyDialectSymbolName(StringRef &prettyName) { + // Pretty symbol names are a relatively unstructured format that contains a + // series of properly nested punctuation, with anything else in the middle. + // Scan ahead to find it and consume it if successful, otherwise emit an + // error. + auto *curPtr = getTokenSpelling().data(); + + SmallVector nestedPunctuation; + + // Scan over the nested punctuation, bailing out on error and consuming until + // we find the end. We know that we're currently looking at the '<', so we + // can go until we find the matching '>' character. + assert(*curPtr == '<'); + do { + char c = *curPtr++; + switch (c) { + case '\0': + // This also handles the EOF case. + return emitError("unexpected nul or EOF in pretty dialect name"); + case '<': + case '[': + case '(': + case '{': + nestedPunctuation.push_back(c); + continue; + + case '-': + // The sequence `->` is treated as special token.
+ if (*curPtr == '>') + ++curPtr; + continue; + + case '>': + if (nestedPunctuation.pop_back_val() != '<') + return emitError("unbalanced '>' character in pretty dialect name"); + break; + case ']': + if (nestedPunctuation.pop_back_val() != '[') + return emitError("unbalanced ']' character in pretty dialect name"); + break; + case ')': + if (nestedPunctuation.pop_back_val() != '(') + return emitError("unbalanced ')' character in pretty dialect name"); + break; + case '}': + if (nestedPunctuation.pop_back_val() != '{') + return emitError("unbalanced '}' character in pretty dialect name"); + break; + + default: + continue; + } + } while (!nestedPunctuation.empty()); + + // Ok, we succeeded, remember where we stopped, reset the lexer to know it is + // consuming all this stuff, and return. + state.lex.resetPointer(curPtr); + + unsigned length = curPtr - prettyName.begin(); + prettyName = StringRef(prettyName.begin(), length); + consumeToken(); + return success(); +} + +/// Parse an extended dialect symbol: either a symbol alias resolved through 'aliases', or a dialect-specific symbol built by 'createSymbol'. +template +static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok, + SymbolAliasMap &aliases, + CreateFn &&createSymbol) { + // Parse the dialect namespace. + StringRef identifier = p.getTokenSpelling().drop_front(); + auto loc = p.getToken().getLoc(); + p.consumeToken(identifierTok); + + // If there is no '<' token following this, and if the typename contains no + // dot, then we are parsing a symbol alias. + if (p.getToken().isNot(Token::less) && !identifier.contains('.')) { + // Check for an alias for this symbol. + auto aliasIt = aliases.find(identifier); + if (aliasIt == aliases.end()) + return (p.emitError("undefined symbol alias id '" + identifier + "'"), + nullptr); + return aliasIt->second; + } + + // Otherwise, we are parsing a dialect-specific symbol. If the name contains + // a dot, then this is the "pretty" form. If not, it is the verbose form that + // looks like <"...">.
+ std::string symbolData; + auto dialectName = identifier; + + // Handle the verbose form, where "identifier" is a simple dialect name. + if (!identifier.contains('.')) { + // Consume the '<'. + if (p.parseToken(Token::less, "expected '<' in dialect type")) // NOTE(review): also reached for attributes, yet the message says 'type'. + return nullptr; + + // Parse the symbol specific data. + if (p.getToken().isNot(Token::string)) + return (p.emitError("expected string literal data in dialect symbol"), + nullptr); + symbolData = p.getToken().getStringValue(); + loc = llvm::SMLoc::getFromPointer(p.getToken().getLoc().getPointer() + 1); + p.consumeToken(Token::string); + + // Consume the '>'. + if (p.parseToken(Token::greater, "expected '>' in dialect symbol")) + return nullptr; + } else { + // Ok, the dialect name is the part of the identifier before the dot, the + // part after the dot is the dialect's symbol, or the start thereof. + auto dotHalves = identifier.split('.'); + dialectName = dotHalves.first; + auto prettyName = dotHalves.second; + loc = llvm::SMLoc::getFromPointer(prettyName.data()); + + // If the dialect's symbol is followed immediately by a <, then lex the body + // of it into prettyName. + if (p.getToken().is(Token::less) && + prettyName.bytes_end() == p.getTokenSpelling().bytes_begin()) { + if (p.parsePrettyDialectSymbolName(prettyName)) + return nullptr; + } + + symbolData = prettyName.str(); + } + + // Record the name location of the symbol remapped to the top level buffer. + llvm::SMLoc locInTopLevelBuffer = p.remapLocationToTopLevelBuffer(loc); + p.getState().symbols.nestedParserLocs.push_back(locInTopLevelBuffer); + + // Call into the provided symbol construction function. + Symbol sym = createSymbol(dialectName, symbolData, loc); + + // Pop the last parser location. + p.getState().symbols.nestedParserLocs.pop_back(); + return sym; +} + +/// Parses a symbol, of type 'T', and returns it if parsing was successful. If +/// parsing failed, a null 'T' is returned.
If 'numRead' is non-null, the number of bytes read from the input +/// string is stored there; otherwise parsing must either reach the end of the input or consume nothing, else an error is emitted. +template +static T parseSymbol(StringRef inputStr, MLIRContext *context, + SymbolState &symbolState, ParserFn &&parserFn, + size_t *numRead = nullptr) { + SourceMgr sourceMgr; + auto memBuffer = MemoryBuffer::getMemBuffer( + inputStr, /*BufferName=*/"", + /*RequiresNullTerminator=*/false); + sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); + ParserState state(sourceMgr, context, symbolState); + Parser parser(state); + + Token startTok = parser.getToken(); + T symbol = parserFn(parser); + if (!symbol) + return T(); + + // If 'numRead' is valid, then provide the number of bytes that were read. + Token endTok = parser.getToken(); + if (numRead) { + *numRead = static_cast(endTok.getLoc().getPointer() - + startTok.getLoc().getPointer()); + + // Otherwise, ensure that all of the tokens were parsed. + } else if (startTok.getLoc() != endTok.getLoc() && endTok.isNot(Token::eof)) { + parser.emitError(endTok.getLoc(), "encountered unexpected token"); + return T(); + } + return symbol; +} + +/// Parse an extended attribute. +/// +/// extended-attribute ::= (dialect-attribute | attribute-alias) +/// dialect-attribute ::= `#` dialect-namespace `<` `"` attr-data `"` `>` +/// dialect-attribute ::= `#` alias-name pretty-dialect-sym-body? +/// attribute-alias ::= `#` alias-name +/// +Attribute Parser::parseExtendedAttr(Type type) { + Attribute attr = parseExtendedSymbol( + *this, Token::hash_identifier, state.symbols.attributeAliasDefinitions, + [&](StringRef dialectName, StringRef symbolData, + llvm::SMLoc loc) -> Attribute { + // Parse an optional trailing colon type. + Type attrType = type; + if (consumeIf(Token::colon) && !(attrType = parseType())) + return Attribute(); + + // If we found a registered dialect, then ask it to parse the attribute.
+ if (auto *dialect = state.context->getRegisteredDialect(dialectName)) { + return parseSymbol( + symbolData, state.context, state.symbols, [&](Parser &parser) { + CustomDialectAsmParser customParser(symbolData, parser); + return dialect->parseAttribute(customParser, attrType); + }); + } + + // Otherwise, form a new opaque attribute. + return OpaqueAttr::getChecked( + Identifier::get(dialectName, state.context), symbolData, + attrType ? attrType : NoneType::get(state.context), + getEncodedSourceLocation(loc)); + }); + + // Ensure that the attribute has the same type as requested. + if (attr && type && attr.getType() != type) { + emitError("attribute type different than expected: expected ") + << type << ", but got " << attr.getType(); + return nullptr; + } + return attr; +} + +/// Parse an extended type. +/// +/// extended-type ::= (dialect-type | type-alias) +/// dialect-type ::= `!` dialect-namespace `<` `"` type-data `"` `>` +/// dialect-type ::= `!` alias-name pretty-dialect-sym-body? +/// type-alias ::= `!` alias-name +/// +Type Parser::parseExtendedType() { + return parseExtendedSymbol( + *this, Token::exclamation_identifier, state.symbols.typeAliasDefinitions, + [&](StringRef dialectName, StringRef symbolData, + llvm::SMLoc loc) -> Type { + // If we found a registered dialect, then ask it to parse the type. + if (auto *dialect = state.context->getRegisteredDialect(dialectName)) { + return parseSymbol( + symbolData, state.context, state.symbols, [&](Parser &parser) { + CustomDialectAsmParser customParser(symbolData, parser); + return dialect->parseType(customParser); + }); + } + + // Otherwise, form a new opaque type.
+ return OpaqueType::getChecked( + Identifier::get(dialectName, state.context), symbolData, + state.context, getEncodedSourceLocation(loc)); + }); +} + +//===----------------------------------------------------------------------===// +// mlir::parseAttribute/parseType +//===----------------------------------------------------------------------===// + +/// Parses a symbol, of type 'T', using a fresh SymbolState and with a +/// SourceMgrDiagnosticHandler installed around the parse. Returns the symbol on success, or a null 'T' on failure. The number of bytes read from the input +/// string is returned in 'numRead'. +template +static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t &numRead, + ParserFn &&parserFn) { + SymbolState aliasState; + return parseSymbol( + inputStr, context, aliasState, + [&](Parser &parser) { + SourceMgrDiagnosticHandler handler( + const_cast(parser.getSourceMgr()), + parser.getContext()); + return parserFn(parser); + }, + &numRead); +} + +Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context) { + size_t numRead = 0; // Discarded; trailing input after the attribute is not diagnosed. + return parseAttribute(attrStr, context, numRead); +} +Attribute mlir::parseAttribute(StringRef attrStr, Type type) { + size_t numRead = 0; // Discarded; trailing input after the attribute is not diagnosed. + return parseAttribute(attrStr, type, numRead); +} + +Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context, + size_t &numRead) { + return parseSymbol(attrStr, context, numRead, [](Parser &parser) { + return parser.parseAttribute(); + }); +} +Attribute mlir::parseAttribute(StringRef attrStr, Type type, size_t &numRead) { + return parseSymbol( + attrStr, type.getContext(), numRead, + [type](Parser &parser) { return parser.parseAttribute(type); }); +} + +Type mlir::parseType(StringRef typeStr, MLIRContext *context) { + size_t numRead = 0; // Discarded; trailing input after the type is not diagnosed. 
+ return parseType(typeStr, context, numRead); +} + +Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t &numRead) { + return parseSymbol(typeStr, context, numRead, + [](Parser &parser) { return parser.parseType(); }); +} diff --git
a/mlir/lib/Parser/LocationParser.cpp b/mlir/lib/Parser/LocationParser.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Parser/LocationParser.cpp @@ -0,0 +1,197 @@ +//===- LocationParser.cpp - MLIR Location Parser -------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Parser.h" + +using namespace mlir; +using namespace mlir::detail; + +/// Parse a location. +/// +/// location ::= `loc` inline-location +/// inline-location ::= '(' location-inst ')' +/// +ParseResult Parser::parseLocation(LocationAttr &loc) { + // Check for 'loc' identifier. + if (parseToken(Token::kw_loc, "expected 'loc' keyword")) + return failure(); // parseToken has already emitted the diagnostic. + + // Parse the inline-location. + if (parseToken(Token::l_paren, "expected '(' in inline location") || + parseLocationInstance(loc) || + parseToken(Token::r_paren, "expected ')' in inline location")) + return failure(); + return success(); +} + +/// Specific location instances. +/// +/// location-inst ::= filelinecol-location | +/// name-location | +/// callsite-location | +/// fused-location | +/// unknown-location +/// filelinecol-location ::= string-literal ':' integer-literal +/// ':' integer-literal +/// name-location ::= string-literal ('(' location-inst ')')? +/// callsite-location ::= 'callsite' '(' location-inst 'at' location-inst ')' +/// fused-location ::= 'fused' ('<' attribute-value '>')? +/// '[' location-inst (',' location-inst)* ']' +/// unknown-location ::= 'unknown' +/// +ParseResult Parser::parseCallSiteLocation(LocationAttr &loc) { + consumeToken(Token::bare_identifier); + + // Parse the '('. + if (parseToken(Token::l_paren, "expected '(' in callsite location")) + return failure(); + + // Parse the callee location.
+ LocationAttr calleeLoc; + if (parseLocationInstance(calleeLoc)) + return failure(); + + // Parse the 'at'. + if (getToken().isNot(Token::bare_identifier) || + getToken().getSpelling() != "at") + return emitError("expected 'at' in callsite location"); + consumeToken(Token::bare_identifier); + + // Parse the caller location. + LocationAttr callerLoc; + if (parseLocationInstance(callerLoc)) + return failure(); + + // Parse the ')'. + if (parseToken(Token::r_paren, "expected ')' in callsite location")) + return failure(); + + // Return the callsite location. + loc = CallSiteLoc::get(calleeLoc, callerLoc); + return success(); +} + +ParseResult Parser::parseFusedLocation(LocationAttr &loc) { + consumeToken(Token::bare_identifier); + + // Try to parse the optional metadata. + Attribute metadata; + if (consumeIf(Token::less)) { + metadata = parseAttribute(); + if (!metadata) + return emitError("expected valid attribute metadata"); + // Parse the '>' token. + if (parseToken(Token::greater, + "expected '>' after fused location metadata")) + return failure(); + } + + SmallVector locations; + auto parseElt = [&] { + LocationAttr newLoc; + if (parseLocationInstance(newLoc)) + return failure(); + locations.push_back(newLoc); + return success(); + }; + + if (parseToken(Token::l_square, "expected '[' in fused location") || + parseCommaSeparatedList(parseElt) || + parseToken(Token::r_square, "expected ']' in fused location")) + return failure(); + + // Return the fused location. + loc = FusedLoc::get(locations, metadata, getContext()); + return success(); +} + +ParseResult Parser::parseNameOrFileLineColLocation(LocationAttr &loc) { + auto *ctx = getContext(); + auto str = getToken().getStringValue(); + consumeToken(Token::string); + + // If the next token is ':' this is a filelinecol location. + if (consumeIf(Token::colon)) { + // Parse the line number.
+ if (getToken().isNot(Token::integer)) + return emitError("expected integer line number in FileLineColLoc"); + auto line = getToken().getUnsignedIntegerValue(); + if (!line.hasValue()) + return emitError("expected integer line number in FileLineColLoc"); + consumeToken(Token::integer); + + // Parse the ':'. + if (parseToken(Token::colon, "expected ':' in FileLineColLoc")) + return failure(); + + // Parse the column number. + if (getToken().isNot(Token::integer)) + return emitError("expected integer column number in FileLineColLoc"); + auto column = getToken().getUnsignedIntegerValue(); + if (!column.hasValue()) + return emitError("expected integer column number in FileLineColLoc"); + consumeToken(Token::integer); + + loc = FileLineColLoc::get(str, line.getValue(), column.getValue(), ctx); + return success(); + } + + // Otherwise, this is a NameLoc. + + // Check for a child location. + if (consumeIf(Token::l_paren)) { + auto childSourceLoc = getToken().getLoc(); + + // Parse the child location. + LocationAttr childLoc; + if (parseLocationInstance(childLoc)) + return failure(); + + // The child must not be another NameLoc. + if (childLoc.isa()) + return emitError(childSourceLoc, + "child of NameLoc cannot be another NameLoc"); + loc = NameLoc::get(Identifier::get(str, ctx), childLoc); + + // Parse the closing ')'. + if (parseToken(Token::r_paren, + "expected ')' after child location of NameLoc")) + return failure(); + } else { + loc = NameLoc::get(Identifier::get(str, ctx), ctx); + } + + return success(); +} + +ParseResult Parser::parseLocationInstance(LocationAttr &loc) { + // Handle either name or filelinecol locations. + if (getToken().is(Token::string)) + return parseNameOrFileLineColLocation(loc); + + // Bare tokens required for other cases. + if (!getToken().is(Token::bare_identifier)) + return emitError("expected location instance"); + + // Check for the 'callsite' signifying a callsite location.
+ if (getToken().getSpelling() == "callsite") + return parseCallSiteLocation(loc); + + // If the token is 'fused', then this is a fused location. + if (getToken().getSpelling() == "fused") + return parseFusedLocation(loc); + + // Check for an 'unknown' for an unknown location. + if (getToken().getSpelling() == "unknown") { + consumeToken(Token::bare_identifier); + loc = UnknownLoc::get(getContext()); + return success(); + } + + return emitError("expected location instance"); +} diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h new file mode 100644 --- /dev/null +++ b/mlir/lib/Parser/Parser.h @@ -0,0 +1,270 @@ +//===- Parser.h - MLIR Base Parser Class ------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_LIB_PARSER_PARSER_H +#define MLIR_LIB_PARSER_PARSER_H + +#include "ParserState.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpImplementation.h" + +namespace mlir { +namespace detail { +//===----------------------------------------------------------------------===// +// Parser +//===----------------------------------------------------------------------===// + +/// This class implements support for parsing global entities like attributes and +/// types. It is intended to be subclassed by specialized subparsers that +/// include state. +class Parser { +public: + Builder builder; + + Parser(ParserState &state) : builder(state.context), state(state) {} + + // Helper methods to get stuff from the parser-global state.
+ ParserState &getState() const { return state; } + MLIRContext *getContext() const { return state.context; } + const llvm::SourceMgr &getSourceMgr() { return state.lex.getSourceMgr(); } + + /// Parse a comma-separated list of elements up until the specified end token. + ParseResult + parseCommaSeparatedListUntil(Token::Kind rightToken, + const std::function &parseElement, + bool allowEmptyList = true); + + /// Parse a comma separated list of elements that must have at least one entry + /// in it. + ParseResult + parseCommaSeparatedList(const std::function &parseElement); + + ParseResult parsePrettyDialectSymbolName(StringRef &prettyName); + + // We have two forms of parsing methods - those that return a non-null + // pointer on success, and those that return a ParseResult to indicate whether + // they returned a failure. The second class fills in by-reference arguments + // as the results of their action. + + //===--------------------------------------------------------------------===// + // Error Handling + //===--------------------------------------------------------------------===// + + /// Emit an error and return failure. + InFlightDiagnostic emitError(const Twine &message = {}) { + return emitError(state.curToken.getLoc(), message); + } + InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message = {}); + + /// Encode the specified source location information into an attribute for + /// attachment to the IR. + Location getEncodedSourceLocation(llvm::SMLoc loc) { + // If there are no active nested parsers, we can get the encoded source + // location directly. + if (state.parserDepth == 0) + return state.lex.getEncodedSourceLocation(loc); + // Otherwise, we need to re-encode it to point to the top level buffer. + return state.symbols.topLevelLexer->getEncodedSourceLocation( + remapLocationToTopLevelBuffer(loc)); + } + + /// Remaps the given SMLoc to the top level lexer of the parser.
This is used + /// to adjust locations of potentially nested parsers to ensure that they can + /// be emitted properly as diagnostics. + llvm::SMLoc remapLocationToTopLevelBuffer(llvm::SMLoc loc) { + // If there are no active nested parsers, we can return location directly. + SymbolState &symbols = state.symbols; + if (state.parserDepth == 0) + return loc; + assert(symbols.topLevelLexer && "expected valid top-level lexer"); + + // Otherwise, we need to remap the location to the main parser. This is + // simply offsetting the location onto the location of the last nested + // parser. + size_t offset = loc.getPointer() - state.lex.getBufferBegin(); + auto *rawLoc = + symbols.nestedParserLocs[state.parserDepth - 1].getPointer() + offset; + return llvm::SMLoc::getFromPointer(rawLoc); + } + + //===--------------------------------------------------------------------===// + // Token Parsing + //===--------------------------------------------------------------------===// + + /// Return the current token the parser is inspecting. + const Token &getToken() const { return state.curToken; } + StringRef getTokenSpelling() const { return state.curToken.getSpelling(); } + + /// If the current token has the specified kind, consume it and return true. + /// If not, return false. + bool consumeIf(Token::Kind kind) { + if (state.curToken.isNot(kind)) + return false; + consumeToken(kind); + return true; + } + + /// Advance the current lexer onto the next token. + void consumeToken() { + assert(state.curToken.isNot(Token::eof, Token::error) && + "shouldn't advance past EOF or errors"); + state.curToken = state.lex.lexToken(); + } + + /// Advance the current lexer onto the next token, asserting what the expected + /// current token is. This is preferred to the above method because it leads + /// to more self-documenting code with better checking.
+ void consumeToken(Token::Kind kind) { + assert(state.curToken.is(kind) && "consumed an unexpected token"); + consumeToken(); + } + + /// Consume the specified token if present and return success. On failure, + /// output a diagnostic and return failure. + ParseResult parseToken(Token::Kind expectedToken, const Twine &message); + + //===--------------------------------------------------------------------===// + // Type Parsing + //===--------------------------------------------------------------------===// + + ParseResult parseFunctionResultTypes(SmallVectorImpl &elements); + ParseResult parseTypeListNoParens(SmallVectorImpl &elements); + ParseResult parseTypeListParens(SmallVectorImpl &elements); + + /// Optionally parse a type. + OptionalParseResult parseOptionalType(Type &type); + + /// Parse an arbitrary type. + Type parseType(); + + /// Parse a complex type. + Type parseComplexType(); + + /// Parse an extended type. + Type parseExtendedType(); + + /// Parse a function type. + Type parseFunctionType(); + + /// Parse a memref type. + Type parseMemRefType(); + + /// Parse a non function type. + Type parseNonFunctionType(); + + /// Parse a tensor type. + Type parseTensorType(); + + /// Parse a tuple type. + Type parseTupleType(); + + /// Parse a vector type. + VectorType parseVectorType(); + ParseResult parseDimensionListRanked(SmallVectorImpl &dimensions, + bool allowDynamic = true); + ParseResult parseXInDimensionList(); + + /// Parse strided layout specification. + ParseResult parseStridedLayout(int64_t &offset, + SmallVectorImpl &strides); + + // Parse a brace-delimiter list of comma-separated integers with `?` as an + // unknown marker. + ParseResult parseStrideList(SmallVectorImpl &dimensions); + + //===--------------------------------------------------------------------===// + // Attribute Parsing + //===--------------------------------------------------------------------===// + + /// Parse an arbitrary attribute with an optional type.
+ Attribute parseAttribute(Type type = {}); + + /// Parse an attribute dictionary. + ParseResult parseAttributeDict(NamedAttrList &attributes); + + /// Parse an extended attribute. + Attribute parseExtendedAttr(Type type); + + /// Parse a float attribute. + Attribute parseFloatAttr(Type type, bool isNegative); + + /// Parse a decimal or a hexadecimal literal, which can be either an integer + /// or a float attribute. + Attribute parseDecOrHexAttr(Type type, bool isNegative); + + /// Parse an opaque elements attribute. + Attribute parseOpaqueElementsAttr(Type attrType); + + /// Parse a dense elements attribute. + Attribute parseDenseElementsAttr(Type attrType); + ShapedType parseElementsLiteralType(Type type); + + /// Parse a sparse elements attribute. + Attribute parseSparseElementsAttr(Type attrType); + + //===--------------------------------------------------------------------===// + // Location Parsing + //===--------------------------------------------------------------------===// + + /// Parse an inline location. + ParseResult parseLocation(LocationAttr &loc); + + /// Parse a raw location instance. + ParseResult parseLocationInstance(LocationAttr &loc); + + /// Parse a callsite location instance. + ParseResult parseCallSiteLocation(LocationAttr &loc); + + /// Parse a fused location instance. + ParseResult parseFusedLocation(LocationAttr &loc); + + /// Parse a name or FileLineCol location instance. + ParseResult parseNameOrFileLineColLocation(LocationAttr &loc); + + /// Parse an optional trailing location. + /// + /// trailing-location ::= (`loc` `(` location `)`)? + /// + ParseResult parseOptionalTrailingLocation(Location &loc) { + // If there is a 'loc' we parse a trailing location. + if (!getToken().is(Token::kw_loc)) + return success(); + + // Parse the location.
+ LocationAttr directLoc; + if (parseLocation(directLoc)) + return failure(); + loc = directLoc; + return success(); + } + + //===--------------------------------------------------------------------===// + // Affine Parsing + //===--------------------------------------------------------------------===// + + /// Parse a reference to either an affine map, or an integer set. + ParseResult parseAffineMapOrIntegerSetReference(AffineMap &map, + IntegerSet &set); + ParseResult parseAffineMapReference(AffineMap &map); + ParseResult parseIntegerSetReference(IntegerSet &set); + + /// Parse an AffineMap where the dim and symbol identifiers are SSA ids. + ParseResult + parseAffineMapOfSSAIds(AffineMap &map, + function_ref parseElement, + OpAsmParser::Delimiter delimiter); + +private: + /// The Parser is subclassed and reinstantiated. Do not add additional + /// non-trivial state here, add it to the ParserState class. + ParserState &state; +}; +} // end namespace detail +} // end namespace mlir + +#endif // MLIR_LIB_PARSER_PARSER_H diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -10,3310 +10,86 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Parser.h" -#include "Lexer.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/DialectImplementation.h" -#include "mlir/IR/IntegerSet.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Module.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/IR/Verifier.h" -#include "llvm/ADT/APInt.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/StringExtras.h" -#include "llvm/ADT/StringSet.h" -#include "llvm/ADT/bit.h" -#include "llvm/Support/PrettyStackTrace.h" -#include "llvm/Support/SMLoc.h"
-#include "llvm/Support/SourceMgr.h" -#include -using namespace mlir; -using llvm::MemoryBuffer; -using llvm::SMLoc; -using llvm::SourceMgr; - -namespace { -class Parser; - -//===----------------------------------------------------------------------===// -// SymbolState -//===----------------------------------------------------------------------===// - -/// This class contains record of any parsed top-level symbols. -struct SymbolState { - // A map from attribute alias identifier to Attribute. - llvm::StringMap attributeAliasDefinitions; - - // A map from type alias identifier to Type. - llvm::StringMap typeAliasDefinitions; - - /// A set of locations into the main parser memory buffer for each of the - /// active nested parsers. Given that some nested parsers, i.e. custom dialect - /// parsers, operate on a temporary memory buffer, this provides an anchor - /// point for emitting diagnostics. - SmallVector nestedParserLocs; - - /// The top-level lexer that contains the original memory buffer provided by - /// the user. This is used by nested parsers to get a properly encoded source - /// location. - Lexer *topLevelLexer = nullptr; -}; - -//===----------------------------------------------------------------------===// -// ParserState -//===----------------------------------------------------------------------===// - -/// This class refers to all of the state maintained globally by the parser, -/// such as the current lexer position etc. -struct ParserState { - ParserState(const llvm::SourceMgr &sourceMgr, MLIRContext *ctx, - SymbolState &symbols) - : context(ctx), lex(sourceMgr, ctx), curToken(lex.lexToken()), - symbols(symbols), parserDepth(symbols.nestedParserLocs.size()) { - // Set the top level lexer for the symbol state if one doesn't exist. - if (!symbols.topLevelLexer) - symbols.topLevelLexer = &lex; - } - ~ParserState() { - // Reset the top level lexer if it refers the lexer in our state. 
- if (symbols.topLevelLexer == &lex) - symbols.topLevelLexer = nullptr; - } - ParserState(const ParserState &) = delete; - void operator=(const ParserState &) = delete; - - /// The context we're parsing into. - MLIRContext *const context; - - /// The lexer for the source file we're parsing. - Lexer lex; - - /// This is the next token that hasn't been consumed yet. - Token curToken; - - /// The current state for symbol parsing. - SymbolState &symbols; - - /// The depth of this parser in the nested parsing stack. - size_t parserDepth; -}; - -//===----------------------------------------------------------------------===// -// Parser -//===----------------------------------------------------------------------===// - -/// This class implement support for parsing global entities like types and -/// shared entities like SSA names. It is intended to be subclassed by -/// specialized subparsers that include state, e.g. when a local symbol table. -class Parser { -public: - Builder builder; - - Parser(ParserState &state) : builder(state.context), state(state) {} - - // Helper methods to get stuff from the parser-global state. - ParserState &getState() const { return state; } - MLIRContext *getContext() const { return state.context; } - const llvm::SourceMgr &getSourceMgr() { return state.lex.getSourceMgr(); } - - /// Parse a comma-separated list of elements up until the specified end token. - ParseResult - parseCommaSeparatedListUntil(Token::Kind rightToken, - const std::function &parseElement, - bool allowEmptyList = true); - - /// Parse a comma separated list of elements that must have at least one entry - /// in it. - ParseResult - parseCommaSeparatedList(const std::function &parseElement); - - ParseResult parsePrettyDialectSymbolName(StringRef &prettyName); - - // We have two forms of parsing methods - those that return a non-null - // pointer on success, and those that return a ParseResult to indicate whether - // they returned a failure. 
The second class fills in by-reference arguments - // as the results of their action. - - //===--------------------------------------------------------------------===// - // Error Handling - //===--------------------------------------------------------------------===// - - /// Emit an error and return failure. - InFlightDiagnostic emitError(const Twine &message = {}) { - return emitError(state.curToken.getLoc(), message); - } - InFlightDiagnostic emitError(SMLoc loc, const Twine &message = {}); - - /// Encode the specified source location information into an attribute for - /// attachment to the IR. - Location getEncodedSourceLocation(llvm::SMLoc loc) { - // If there are no active nested parsers, we can get the encoded source - // location directly. - if (state.parserDepth == 0) - return state.lex.getEncodedSourceLocation(loc); - // Otherwise, we need to re-encode it to point to the top level buffer. - return state.symbols.topLevelLexer->getEncodedSourceLocation( - remapLocationToTopLevelBuffer(loc)); - } - - /// Remaps the given SMLoc to the top level lexer of the parser. This is used - /// to adjust locations of potentially nested parsers to ensure that they can - /// be emitted properly as diagnostics. - llvm::SMLoc remapLocationToTopLevelBuffer(llvm::SMLoc loc) { - // If there are no active nested parsers, we can return location directly. - SymbolState &symbols = state.symbols; - if (state.parserDepth == 0) - return loc; - assert(symbols.topLevelLexer && "expected valid top-level lexer"); - - // Otherwise, we need to remap the location to the main parser. This is - // simply offseting the location onto the location of the last nested - // parser. 
- size_t offset = loc.getPointer() - state.lex.getBufferBegin(); - auto *rawLoc = - symbols.nestedParserLocs[state.parserDepth - 1].getPointer() + offset; - return llvm::SMLoc::getFromPointer(rawLoc); - } - - //===--------------------------------------------------------------------===// - // Token Parsing - //===--------------------------------------------------------------------===// - - /// Return the current token the parser is inspecting. - const Token &getToken() const { return state.curToken; } - StringRef getTokenSpelling() const { return state.curToken.getSpelling(); } - - /// If the current token has the specified kind, consume it and return true. - /// If not, return false. - bool consumeIf(Token::Kind kind) { - if (state.curToken.isNot(kind)) - return false; - consumeToken(kind); - return true; - } - - /// Advance the current lexer onto the next token. - void consumeToken() { - assert(state.curToken.isNot(Token::eof, Token::error) && - "shouldn't advance past EOF or errors"); - state.curToken = state.lex.lexToken(); - } - - /// Advance the current lexer onto the next token, asserting what the expected - /// current token is. This is preferred to the above method because it leads - /// to more self-documenting code with better checking. - void consumeToken(Token::Kind kind) { - assert(state.curToken.is(kind) && "consumed an unexpected token"); - consumeToken(); - } - - /// Consume the specified token if present and return success. On failure, - /// output a diagnostic and return failure. 
- ParseResult parseToken(Token::Kind expectedToken, const Twine &message); - - //===--------------------------------------------------------------------===// - // Type Parsing - //===--------------------------------------------------------------------===// - - ParseResult parseFunctionResultTypes(SmallVectorImpl &elements); - ParseResult parseTypeListNoParens(SmallVectorImpl &elements); - ParseResult parseTypeListParens(SmallVectorImpl &elements); - - /// Optionally parse a type. - OptionalParseResult parseOptionalType(Type &type); - - /// Parse an arbitrary type. - Type parseType(); - - /// Parse a complex type. - Type parseComplexType(); - - /// Parse an extended type. - Type parseExtendedType(); - - /// Parse a function type. - Type parseFunctionType(); - - /// Parse a memref type. - Type parseMemRefType(); - - /// Parse a non function type. - Type parseNonFunctionType(); - - /// Parse a tensor type. - Type parseTensorType(); - - /// Parse a tuple type. - Type parseTupleType(); - - /// Parse a vector type. - VectorType parseVectorType(); - ParseResult parseDimensionListRanked(SmallVectorImpl &dimensions, - bool allowDynamic = true); - ParseResult parseXInDimensionList(); - - /// Parse strided layout specification. - ParseResult parseStridedLayout(int64_t &offset, - SmallVectorImpl &strides); - - // Parse a brace-delimiter list of comma-separated integers with `?` as an - // unknown marker. - ParseResult parseStrideList(SmallVectorImpl &dimensions); - - //===--------------------------------------------------------------------===// - // Attribute Parsing - //===--------------------------------------------------------------------===// - - /// Parse an arbitrary attribute with an optional type. - Attribute parseAttribute(Type type = {}); - - /// Parse an attribute dictionary. - ParseResult parseAttributeDict(NamedAttrList &attributes); - - /// Parse an extended attribute. - Attribute parseExtendedAttr(Type type); - - /// Parse a float attribute. 
- Attribute parseFloatAttr(Type type, bool isNegative); - - /// Parse a decimal or a hexadecimal literal, which can be either an integer - /// or a float attribute. - Attribute parseDecOrHexAttr(Type type, bool isNegative); - - /// Parse an opaque elements attribute. - Attribute parseOpaqueElementsAttr(Type attrType); - - /// Parse a dense elements attribute. - Attribute parseDenseElementsAttr(Type attrType); - ShapedType parseElementsLiteralType(Type type); - - /// Parse a sparse elements attribute. - Attribute parseSparseElementsAttr(Type attrType); - - //===--------------------------------------------------------------------===// - // Location Parsing - //===--------------------------------------------------------------------===// - - /// Parse an inline location. - ParseResult parseLocation(LocationAttr &loc); - - /// Parse a raw location instance. - ParseResult parseLocationInstance(LocationAttr &loc); - - /// Parse a callsite location instance. - ParseResult parseCallSiteLocation(LocationAttr &loc); - - /// Parse a fused location instance. - ParseResult parseFusedLocation(LocationAttr &loc); - - /// Parse a name or FileLineCol location instance. - ParseResult parseNameOrFileLineColLocation(LocationAttr &loc); - - /// Parse an optional trailing location. - /// - /// trailing-location ::= (`loc` `(` location `)`)? - /// - ParseResult parseOptionalTrailingLocation(Location &loc) { - // If there is a 'loc' we parse a trailing location. - if (!getToken().is(Token::kw_loc)) - return success(); - - // Parse the location. - LocationAttr directLoc; - if (parseLocation(directLoc)) - return failure(); - loc = directLoc; - return success(); - } - - //===--------------------------------------------------------------------===// - // Affine Parsing - //===--------------------------------------------------------------------===// - - /// Parse a reference to either an affine map, or an integer set. 
- ParseResult parseAffineMapOrIntegerSetReference(AffineMap &map, - IntegerSet &set); - ParseResult parseAffineMapReference(AffineMap &map); - ParseResult parseIntegerSetReference(IntegerSet &set); - - /// Parse an AffineMap where the dim and symbol identifiers are SSA ids. - ParseResult - parseAffineMapOfSSAIds(AffineMap &map, - function_ref parseElement, - OpAsmParser::Delimiter delimiter); - -private: - /// The Parser is subclassed and reinstantiated. Do not add additional - /// non-trivial state here, add it to the ParserState class. - ParserState &state; -}; -} // end anonymous namespace - -//===----------------------------------------------------------------------===// -// Helper methods. -//===----------------------------------------------------------------------===// - -/// Parse a comma separated list of elements that must have at least one entry -/// in it. -ParseResult Parser::parseCommaSeparatedList( - const std::function &parseElement) { - // Non-empty case starts with an element. - if (parseElement()) - return failure(); - - // Otherwise we have a list of comma separated elements. - while (consumeIf(Token::comma)) { - if (parseElement()) - return failure(); - } - return success(); -} - -/// Parse a comma-separated list of elements, terminated with an arbitrary -/// token. This allows empty lists if allowEmptyList is true. -/// -/// abstract-list ::= rightToken // if allowEmptyList == true -/// abstract-list ::= element (',' element)* rightToken -/// -ParseResult Parser::parseCommaSeparatedListUntil( - Token::Kind rightToken, const std::function &parseElement, - bool allowEmptyList) { - // Handle the empty case. 
- if (getToken().is(rightToken)) { - if (!allowEmptyList) - return emitError("expected list element"); - consumeToken(rightToken); - return success(); - } - - if (parseCommaSeparatedList(parseElement) || - parseToken(rightToken, "expected ',' or '" + - Token::getTokenSpelling(rightToken) + "'")) - return failure(); - - return success(); -} - -//===----------------------------------------------------------------------===// -// DialectAsmParser -//===----------------------------------------------------------------------===// - -namespace { -/// This class provides the main implementation of the DialectAsmParser that -/// allows for dialects to parse attributes and types. This allows for dialect -/// hooking into the main MLIR parsing logic. -class CustomDialectAsmParser : public DialectAsmParser { -public: - CustomDialectAsmParser(StringRef fullSpec, Parser &parser) - : fullSpec(fullSpec), nameLoc(parser.getToken().getLoc()), - parser(parser) {} - ~CustomDialectAsmParser() override {} - - /// Emit a diagnostic at the specified location and return failure. - InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override { - return parser.emitError(loc, message); - } - - /// Return a builder which provides useful access to MLIRContext, global - /// objects like types and attributes. - Builder &getBuilder() const override { return parser.builder; } - - /// Get the location of the next token and store it into the argument. This - /// always succeeds. - llvm::SMLoc getCurrentLocation() override { - return parser.getToken().getLoc(); - } - - /// Return the location of the original name token. - llvm::SMLoc getNameLoc() const override { return nameLoc; } - - /// Re-encode the given source location as an MLIR location and return it. - Location getEncodedSourceLoc(llvm::SMLoc loc) override { - return parser.getEncodedSourceLocation(loc); - } - - /// Returns the full specification of the symbol being parsed. 
This allows - /// for using a separate parser if necessary. - StringRef getFullSymbolSpec() const override { return fullSpec; } - - /// Parse a floating point value from the stream. - ParseResult parseFloat(double &result) override { - bool negative = parser.consumeIf(Token::minus); - Token curTok = parser.getToken(); - - // Check for a floating point value. - if (curTok.is(Token::floatliteral)) { - auto val = curTok.getFloatingPointValue(); - if (!val.hasValue()) - return emitError(curTok.getLoc(), "floating point value too large"); - parser.consumeToken(Token::floatliteral); - result = negative ? -*val : *val; - return success(); - } - - // TODO(riverriddle) support hex floating point values. - return emitError(getCurrentLocation(), "expected floating point literal"); - } - - /// Parse an optional integer value from the stream. - OptionalParseResult parseOptionalInteger(uint64_t &result) override { - Token curToken = parser.getToken(); - if (curToken.isNot(Token::integer, Token::minus)) - return llvm::None; - - bool negative = parser.consumeIf(Token::minus); - Token curTok = parser.getToken(); - if (parser.parseToken(Token::integer, "expected integer value")) - return failure(); - - auto val = curTok.getUInt64IntegerValue(); - if (!val) - return emitError(curTok.getLoc(), "integer value too large"); - result = negative ? -*val : *val; - return success(); - } - - //===--------------------------------------------------------------------===// - // Token Parsing - //===--------------------------------------------------------------------===// - - /// Parse a `->` token. - ParseResult parseArrow() override { - return parser.parseToken(Token::arrow, "expected '->'"); - } - - /// Parses a `->` if present. - ParseResult parseOptionalArrow() override { - return success(parser.consumeIf(Token::arrow)); - } - - /// Parse a '{' token. 
- ParseResult parseLBrace() override { - return parser.parseToken(Token::l_brace, "expected '{'"); - } - - /// Parse a '{' token if present - ParseResult parseOptionalLBrace() override { - return success(parser.consumeIf(Token::l_brace)); - } - - /// Parse a `}` token. - ParseResult parseRBrace() override { - return parser.parseToken(Token::r_brace, "expected '}'"); - } - - /// Parse a `}` token if present - ParseResult parseOptionalRBrace() override { - return success(parser.consumeIf(Token::r_brace)); - } - - /// Parse a `:` token. - ParseResult parseColon() override { - return parser.parseToken(Token::colon, "expected ':'"); - } - - /// Parse a `:` token if present. - ParseResult parseOptionalColon() override { - return success(parser.consumeIf(Token::colon)); - } - - /// Parse a `,` token. - ParseResult parseComma() override { - return parser.parseToken(Token::comma, "expected ','"); - } - - /// Parse a `,` token if present. - ParseResult parseOptionalComma() override { - return success(parser.consumeIf(Token::comma)); - } - - /// Parses a `...` if present. - ParseResult parseOptionalEllipsis() override { - return success(parser.consumeIf(Token::ellipsis)); - } - - /// Parse a `=` token. - ParseResult parseEqual() override { - return parser.parseToken(Token::equal, "expected '='"); - } - - /// Parse a '<' token. - ParseResult parseLess() override { - return parser.parseToken(Token::less, "expected '<'"); - } - - /// Parse a `<` token if present. - ParseResult parseOptionalLess() override { - return success(parser.consumeIf(Token::less)); - } - - /// Parse a '>' token. - ParseResult parseGreater() override { - return parser.parseToken(Token::greater, "expected '>'"); - } - - /// Parse a `>` token if present. - ParseResult parseOptionalGreater() override { - return success(parser.consumeIf(Token::greater)); - } - - /// Parse a `(` token. 
- ParseResult parseLParen() override { - return parser.parseToken(Token::l_paren, "expected '('"); - } - - /// Parses a '(' if present. - ParseResult parseOptionalLParen() override { - return success(parser.consumeIf(Token::l_paren)); - } - - /// Parse a `)` token. - ParseResult parseRParen() override { - return parser.parseToken(Token::r_paren, "expected ')'"); - } - - /// Parses a ')' if present. - ParseResult parseOptionalRParen() override { - return success(parser.consumeIf(Token::r_paren)); - } - - /// Parse a `[` token. - ParseResult parseLSquare() override { - return parser.parseToken(Token::l_square, "expected '['"); - } - - /// Parses a '[' if present. - ParseResult parseOptionalLSquare() override { - return success(parser.consumeIf(Token::l_square)); - } - - /// Parse a `]` token. - ParseResult parseRSquare() override { - return parser.parseToken(Token::r_square, "expected ']'"); - } - - /// Parses a ']' if present. - ParseResult parseOptionalRSquare() override { - return success(parser.consumeIf(Token::r_square)); - } - - /// Parses a '?' if present. - ParseResult parseOptionalQuestion() override { - return success(parser.consumeIf(Token::question)); - } - - /// Parses a '*' if present. - ParseResult parseOptionalStar() override { - return success(parser.consumeIf(Token::star)); - } - - /// Returns if the current token corresponds to a keyword. - bool isCurrentTokenAKeyword() const { - return parser.getToken().is(Token::bare_identifier) || - parser.getToken().isKeyword(); - } - - /// Parse the given keyword if present. - ParseResult parseOptionalKeyword(StringRef keyword) override { - // Check that the current token has the same spelling. - if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword) - return failure(); - parser.consumeToken(); - return success(); - } - - /// Parse a keyword, if present, into 'keyword'. - ParseResult parseOptionalKeyword(StringRef *keyword) override { - // Check that the current token is a keyword. 
- if (!isCurrentTokenAKeyword()) - return failure(); - - *keyword = parser.getTokenSpelling(); - parser.consumeToken(); - return success(); - } - - //===--------------------------------------------------------------------===// - // Attribute Parsing - //===--------------------------------------------------------------------===// - - /// Parse an arbitrary attribute and return it in result. - ParseResult parseAttribute(Attribute &result, Type type) override { - result = parser.parseAttribute(type); - return success(static_cast(result)); - } - - /// Parse an affine map instance into 'map'. - ParseResult parseAffineMap(AffineMap &map) override { - return parser.parseAffineMapReference(map); - } - - /// Parse an integer set instance into 'set'. - ParseResult printIntegerSet(IntegerSet &set) override { - return parser.parseIntegerSetReference(set); - } - - //===--------------------------------------------------------------------===// - // Type Parsing - //===--------------------------------------------------------------------===// - - ParseResult parseType(Type &result) override { - result = parser.parseType(); - return success(static_cast(result)); - } - - ParseResult parseDimensionList(SmallVectorImpl &dimensions, - bool allowDynamic) override { - return parser.parseDimensionListRanked(dimensions, allowDynamic); - } - -private: - /// The full symbol specification. - StringRef fullSpec; - - /// The source location of the dialect symbol. - SMLoc nameLoc; - - /// The main parser. - Parser &parser; -}; -} // namespace - -/// Parse the body of a pretty dialect symbol, which starts and ends with <>'s, -/// and may be recursive. Return with the 'prettyName' StringRef encompassing -/// the entire pretty name. 
-/// -/// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>' -/// pretty-dialect-sym-contents ::= pretty-dialect-sym-body -/// | '(' pretty-dialect-sym-contents+ ')' -/// | '[' pretty-dialect-sym-contents+ ']' -/// | '{' pretty-dialect-sym-contents+ '}' -/// | '[^[<({>\])}\0]+' -/// -ParseResult Parser::parsePrettyDialectSymbolName(StringRef &prettyName) { - // Pretty symbol names are a relatively unstructured format that contains a - // series of properly nested punctuation, with anything else in the middle. - // Scan ahead to find it and consume it if successful, otherwise emit an - // error. - auto *curPtr = getTokenSpelling().data(); - - SmallVector nestedPunctuation; - - // Scan over the nested punctuation, bailing out on error and consuming until - // we find the end. We know that we're currently looking at the '<', so we - // can go until we find the matching '>' character. - assert(*curPtr == '<'); - do { - char c = *curPtr++; - switch (c) { - case '\0': - // This also handles the EOF case. - return emitError("unexpected nul or EOF in pretty dialect name"); - case '<': - case '[': - case '(': - case '{': - nestedPunctuation.push_back(c); - continue; - - case '-': - // The sequence `->` is treated as special token. 
- if (*curPtr == '>') - ++curPtr; - continue; - - case '>': - if (nestedPunctuation.pop_back_val() != '<') - return emitError("unbalanced '>' character in pretty dialect name"); - break; - case ']': - if (nestedPunctuation.pop_back_val() != '[') - return emitError("unbalanced ']' character in pretty dialect name"); - break; - case ')': - if (nestedPunctuation.pop_back_val() != '(') - return emitError("unbalanced ')' character in pretty dialect name"); - break; - case '}': - if (nestedPunctuation.pop_back_val() != '{') - return emitError("unbalanced '}' character in pretty dialect name"); - break; - - default: - continue; - } - } while (!nestedPunctuation.empty()); - - // Ok, we succeeded, remember where we stopped, reset the lexer to know it is - // consuming all this stuff, and return. - state.lex.resetPointer(curPtr); - - unsigned length = curPtr - prettyName.begin(); - prettyName = StringRef(prettyName.begin(), length); - consumeToken(); - return success(); -} - -/// Parse an extended dialect symbol. -template -static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok, - SymbolAliasMap &aliases, - CreateFn &&createSymbol) { - // Parse the dialect namespace. - StringRef identifier = p.getTokenSpelling().drop_front(); - auto loc = p.getToken().getLoc(); - p.consumeToken(identifierTok); - - // If there is no '<' token following this, and if the typename contains no - // dot, then we are parsing a symbol alias. - if (p.getToken().isNot(Token::less) && !identifier.contains('.')) { - // Check for an alias for this type. - auto aliasIt = aliases.find(identifier); - if (aliasIt == aliases.end()) - return (p.emitError("undefined symbol alias id '" + identifier + "'"), - nullptr); - return aliasIt->second; - } - - // Otherwise, we are parsing a dialect-specific symbol. If the name contains - // a dot, then this is the "pretty" form. If not, it is the verbose form that - // looks like <"...">. 
- std::string symbolData; - auto dialectName = identifier; - - // Handle the verbose form, where "identifier" is a simple dialect name. - if (!identifier.contains('.')) { - // Consume the '<'. - if (p.parseToken(Token::less, "expected '<' in dialect type")) - return nullptr; - - // Parse the symbol specific data. - if (p.getToken().isNot(Token::string)) - return (p.emitError("expected string literal data in dialect symbol"), - nullptr); - symbolData = p.getToken().getStringValue(); - loc = llvm::SMLoc::getFromPointer(p.getToken().getLoc().getPointer() + 1); - p.consumeToken(Token::string); - - // Consume the '>'. - if (p.parseToken(Token::greater, "expected '>' in dialect symbol")) - return nullptr; - } else { - // Ok, the dialect name is the part of the identifier before the dot, the - // part after the dot is the dialect's symbol, or the start thereof. - auto dotHalves = identifier.split('.'); - dialectName = dotHalves.first; - auto prettyName = dotHalves.second; - loc = llvm::SMLoc::getFromPointer(prettyName.data()); - - // If the dialect's symbol is followed immediately by a <, then lex the body - // of it into prettyName. - if (p.getToken().is(Token::less) && - prettyName.bytes_end() == p.getTokenSpelling().bytes_begin()) { - if (p.parsePrettyDialectSymbolName(prettyName)) - return nullptr; - } - - symbolData = prettyName.str(); - } - - // Record the name location of the type remapped to the top level buffer. - llvm::SMLoc locInTopLevelBuffer = p.remapLocationToTopLevelBuffer(loc); - p.getState().symbols.nestedParserLocs.push_back(locInTopLevelBuffer); - - // Call into the provided symbol construction function. - Symbol sym = createSymbol(dialectName, symbolData, loc); - - // Pop the last parser location. - p.getState().symbols.nestedParserLocs.pop_back(); - return sym; -} - -/// Parses a symbol, of type 'T', and returns it if parsing was successful. If -/// parsing failed, nullptr is returned. 
The number of bytes read from the input -/// string is returned in 'numRead'. -template -static T parseSymbol(StringRef inputStr, MLIRContext *context, - SymbolState &symbolState, ParserFn &&parserFn, - size_t *numRead = nullptr) { - SourceMgr sourceMgr; - auto memBuffer = MemoryBuffer::getMemBuffer( - inputStr, /*BufferName=*/"", - /*RequiresNullTerminator=*/false); - sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); - ParserState state(sourceMgr, context, symbolState); - Parser parser(state); - - Token startTok = parser.getToken(); - T symbol = parserFn(parser); - if (!symbol) - return T(); - - // If 'numRead' is valid, then provide the number of bytes that were read. - Token endTok = parser.getToken(); - if (numRead) { - *numRead = static_cast(endTok.getLoc().getPointer() - - startTok.getLoc().getPointer()); - - // Otherwise, ensure that all of the tokens were parsed. - } else if (startTok.getLoc() != endTok.getLoc() && endTok.isNot(Token::eof)) { - parser.emitError(endTok.getLoc(), "encountered unexpected token"); - return T(); - } - return symbol; -} - -//===----------------------------------------------------------------------===// -// Error Handling -//===----------------------------------------------------------------------===// - -InFlightDiagnostic Parser::emitError(SMLoc loc, const Twine &message) { - auto diag = mlir::emitError(getEncodedSourceLocation(loc), message); - - // If we hit a parse error in response to a lexer error, then the lexer - // already reported the error. - if (getToken().is(Token::error)) - diag.abandon(); - return diag; -} - -//===----------------------------------------------------------------------===// -// Token Parsing -//===----------------------------------------------------------------------===// - -/// Consume the specified token if present and return success. On failure, -/// output a diagnostic and return failure. 
-ParseResult Parser::parseToken(Token::Kind expectedToken, - const Twine &message) { - if (consumeIf(expectedToken)) - return success(); - return emitError(message); -} - -//===----------------------------------------------------------------------===// -// Type Parsing -//===----------------------------------------------------------------------===// - -/// Optionally parse a type. -OptionalParseResult Parser::parseOptionalType(Type &type) { - // There are many different starting tokens for a type, check them here. - switch (getToken().getKind()) { - case Token::l_paren: - case Token::kw_memref: - case Token::kw_tensor: - case Token::kw_complex: - case Token::kw_tuple: - case Token::kw_vector: - case Token::inttype: - case Token::kw_bf16: - case Token::kw_f16: - case Token::kw_f32: - case Token::kw_f64: - case Token::kw_index: - case Token::kw_none: - case Token::exclamation_identifier: - return failure(!(type = parseType())); - - default: - return llvm::None; - } -} - -/// Parse an arbitrary type. -/// -/// type ::= function-type -/// | non-function-type -/// -Type Parser::parseType() { - if (getToken().is(Token::l_paren)) - return parseFunctionType(); - return parseNonFunctionType(); -} - -/// Parse a function result type. -/// -/// function-result-type ::= type-list-parens -/// | non-function-type -/// -ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl &elements) { - if (getToken().is(Token::l_paren)) - return parseTypeListParens(elements); - - Type t = parseNonFunctionType(); - if (!t) - return failure(); - elements.push_back(t); - return success(); -} - -/// Parse a list of types without an enclosing parenthesis. The list must have -/// at least one member. -/// -/// type-list-no-parens ::= type (`,` type)* -/// -ParseResult Parser::parseTypeListNoParens(SmallVectorImpl &elements) { - auto parseElt = [&]() -> ParseResult { - auto elt = parseType(); - elements.push_back(elt); - return elt ? 
success() : failure(); - }; - - return parseCommaSeparatedList(parseElt); -} - -/// Parse a parenthesized list of types. -/// -/// type-list-parens ::= `(` `)` -/// | `(` type-list-no-parens `)` -/// -ParseResult Parser::parseTypeListParens(SmallVectorImpl &elements) { - if (parseToken(Token::l_paren, "expected '('")) - return failure(); - - // Handle empty lists. - if (getToken().is(Token::r_paren)) - return consumeToken(), success(); - - if (parseTypeListNoParens(elements) || - parseToken(Token::r_paren, "expected ')'")) - return failure(); - return success(); -} - -/// Parse a complex type. -/// -/// complex-type ::= `complex` `<` type `>` -/// -Type Parser::parseComplexType() { - consumeToken(Token::kw_complex); - - // Parse the '<'. - if (parseToken(Token::less, "expected '<' in complex type")) - return nullptr; - - llvm::SMLoc elementTypeLoc = getToken().getLoc(); - auto elementType = parseType(); - if (!elementType || - parseToken(Token::greater, "expected '>' in complex type")) - return nullptr; - if (!elementType.isa() && !elementType.isa()) - return emitError(elementTypeLoc, "invalid element type for complex"), - nullptr; - - return ComplexType::get(elementType); -} - -/// Parse an extended type. -/// -/// extended-type ::= (dialect-type | type-alias) -/// dialect-type ::= `!` dialect-namespace `<` `"` type-data `"` `>` -/// dialect-type ::= `!` alias-name pretty-dialect-attribute-body? -/// type-alias ::= `!` alias-name -/// -Type Parser::parseExtendedType() { - return parseExtendedSymbol( - *this, Token::exclamation_identifier, state.symbols.typeAliasDefinitions, - [&](StringRef dialectName, StringRef symbolData, - llvm::SMLoc loc) -> Type { - // If we found a registered dialect, then ask it to parse the type. 
- if (auto *dialect = state.context->getRegisteredDialect(dialectName)) { - return parseSymbol( - symbolData, state.context, state.symbols, [&](Parser &parser) { - CustomDialectAsmParser customParser(symbolData, parser); - return dialect->parseType(customParser); - }); - } - - // Otherwise, form a new opaque type. - return OpaqueType::getChecked( - Identifier::get(dialectName, state.context), symbolData, - state.context, getEncodedSourceLocation(loc)); - }); -} - -/// Parse a function type. -/// -/// function-type ::= type-list-parens `->` function-result-type -/// -Type Parser::parseFunctionType() { - assert(getToken().is(Token::l_paren)); - - SmallVector arguments, results; - if (parseTypeListParens(arguments) || - parseToken(Token::arrow, "expected '->' in function type") || - parseFunctionResultTypes(results)) - return nullptr; - - return builder.getFunctionType(arguments, results); -} - -/// Parse the offset and strides from a strided layout specification. -/// -/// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list -/// -ParseResult Parser::parseStridedLayout(int64_t &offset, - SmallVectorImpl &strides) { - // Parse offset. - consumeToken(Token::kw_offset); - if (!consumeIf(Token::colon)) - return emitError("expected colon after `offset` keyword"); - auto maybeOffset = getToken().getUnsignedIntegerValue(); - bool question = getToken().is(Token::question); - if (!maybeOffset && !question) - return emitError("invalid offset"); - offset = maybeOffset ? static_cast(maybeOffset.getValue()) - : MemRefType::getDynamicStrideOrOffset(); - consumeToken(); - - if (!consumeIf(Token::comma)) - return emitError("expected comma after offset value"); - - // Parse stride list. 
- if (!consumeIf(Token::kw_strides)) - return emitError("expected `strides` keyword after offset specification"); - if (!consumeIf(Token::colon)) - return emitError("expected colon after `strides` keyword"); - if (failed(parseStrideList(strides))) - return emitError("invalid braces-enclosed stride list"); - if (llvm::any_of(strides, [](int64_t st) { return st == 0; })) - return emitError("invalid memref stride"); - - return success(); -} - -/// Parse a memref type. -/// -/// memref-type ::= ranked-memref-type | unranked-memref-type -/// -/// ranked-memref-type ::= `memref` `<` dimension-list-ranked type -/// (`,` semi-affine-map-composition)? (`,` -/// memory-space)? `>` -/// -/// unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>` -/// -/// semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map -/// memory-space ::= integer-literal /* | TODO: address-space-id */ -/// -Type Parser::parseMemRefType() { - consumeToken(Token::kw_memref); - - if (parseToken(Token::less, "expected '<' in memref type")) - return nullptr; - - bool isUnranked; - SmallVector dimensions; - - if (consumeIf(Token::star)) { - // This is an unranked memref type. - isUnranked = true; - if (parseXInDimensionList()) - return nullptr; - - } else { - isUnranked = false; - if (parseDimensionListRanked(dimensions)) - return nullptr; - } - - // Parse the element type. - auto typeLoc = getToken().getLoc(); - auto elementType = parseType(); - if (!elementType) - return nullptr; - - // Check that memref is formed from allowed types. - if (!elementType.isIntOrFloat() && !elementType.isa() && - !elementType.isa()) - return emitError(typeLoc, "invalid memref element type"), nullptr; - - // Parse semi-affine-map-composition. - SmallVector affineMapComposition; - Optional memorySpace; - unsigned numDims = dimensions.size(); - - auto parseElt = [&]() -> ParseResult { - // Check for the memory space. 
- if (getToken().is(Token::integer)) { - if (memorySpace) - return emitError("multiple memory spaces specified in memref type"); - memorySpace = getToken().getUnsignedIntegerValue(); - if (!memorySpace.hasValue()) - return emitError("invalid memory space in memref type"); - consumeToken(Token::integer); - return success(); - } - if (isUnranked) - return emitError("cannot have affine map for unranked memref type"); - if (memorySpace) - return emitError("expected memory space to be last in memref type"); - - AffineMap map; - llvm::SMLoc mapLoc = getToken().getLoc(); - if (getToken().is(Token::kw_offset)) { - int64_t offset; - SmallVector strides; - if (failed(parseStridedLayout(offset, strides))) - return failure(); - // Construct strided affine map. - map = makeStridedLinearLayoutMap(strides, offset, state.context); - } else { - // Parse an affine map attribute. - auto affineMap = parseAttribute(); - if (!affineMap) - return failure(); - auto affineMapAttr = affineMap.dyn_cast(); - if (!affineMapAttr) - return emitError("expected affine map in memref type"); - map = affineMapAttr.getValue(); - } - - if (map.getNumDims() != numDims) { - size_t i = affineMapComposition.size(); - return emitError(mapLoc, "memref affine map dimension mismatch between ") - << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i)) - << " and affine map" << i + 1 << ": " << numDims - << " != " << map.getNumDims(); - } - numDims = map.getNumResults(); - affineMapComposition.push_back(map); - return success(); - }; - - // Parse a list of mappings and address space if present. - if (!consumeIf(Token::greater)) { - // Parse comma separated list of affine maps, followed by memory space. 
- if (parseToken(Token::comma, "expected ',' or '>' in memref type") || - parseCommaSeparatedListUntil(Token::greater, parseElt, - /*allowEmptyList=*/false)) { - return nullptr; - } - } - - if (isUnranked) - return UnrankedMemRefType::get(elementType, memorySpace.getValueOr(0)); - - return MemRefType::get(dimensions, elementType, affineMapComposition, - memorySpace.getValueOr(0)); -} - -/// Parse any type except the function type. -/// -/// non-function-type ::= integer-type -/// | index-type -/// | float-type -/// | extended-type -/// | vector-type -/// | tensor-type -/// | memref-type -/// | complex-type -/// | tuple-type -/// | none-type -/// -/// index-type ::= `index` -/// float-type ::= `f16` | `bf16` | `f32` | `f64` -/// none-type ::= `none` -/// -Type Parser::parseNonFunctionType() { - switch (getToken().getKind()) { - default: - return (emitError("expected non-function type"), nullptr); - case Token::kw_memref: - return parseMemRefType(); - case Token::kw_tensor: - return parseTensorType(); - case Token::kw_complex: - return parseComplexType(); - case Token::kw_tuple: - return parseTupleType(); - case Token::kw_vector: - return parseVectorType(); - // integer-type - case Token::inttype: { - auto width = getToken().getIntTypeBitwidth(); - if (!width.hasValue()) - return (emitError("invalid integer width"), nullptr); - if (width.getValue() > IntegerType::kMaxWidth) { - emitError(getToken().getLoc(), "integer bitwidth is limited to ") - << IntegerType::kMaxWidth << " bits"; - return nullptr; - } - - IntegerType::SignednessSemantics signSemantics = IntegerType::Signless; - if (Optional signedness = getToken().getIntTypeSignedness()) - signSemantics = *signedness ? 
IntegerType::Signed : IntegerType::Unsigned; - - auto loc = getEncodedSourceLocation(getToken().getLoc()); - consumeToken(Token::inttype); - return IntegerType::getChecked(width.getValue(), signSemantics, loc); - } - - // float-type - case Token::kw_bf16: - consumeToken(Token::kw_bf16); - return builder.getBF16Type(); - case Token::kw_f16: - consumeToken(Token::kw_f16); - return builder.getF16Type(); - case Token::kw_f32: - consumeToken(Token::kw_f32); - return builder.getF32Type(); - case Token::kw_f64: - consumeToken(Token::kw_f64); - return builder.getF64Type(); - - // index-type - case Token::kw_index: - consumeToken(Token::kw_index); - return builder.getIndexType(); - - // none-type - case Token::kw_none: - consumeToken(Token::kw_none); - return builder.getNoneType(); - - // extended type - case Token::exclamation_identifier: - return parseExtendedType(); - } -} - -/// Parse a tensor type. -/// -/// tensor-type ::= `tensor` `<` dimension-list type `>` -/// dimension-list ::= dimension-list-ranked | `*x` -/// -Type Parser::parseTensorType() { - consumeToken(Token::kw_tensor); - - if (parseToken(Token::less, "expected '<' in tensor type")) - return nullptr; - - bool isUnranked; - SmallVector dimensions; - - if (consumeIf(Token::star)) { - // This is an unranked tensor type. - isUnranked = true; - - if (parseXInDimensionList()) - return nullptr; - - } else { - isUnranked = false; - if (parseDimensionListRanked(dimensions)) - return nullptr; - } - - // Parse the element type. - auto elementTypeLoc = getToken().getLoc(); - auto elementType = parseType(); - if (!elementType || parseToken(Token::greater, "expected '>' in tensor type")) - return nullptr; - if (!TensorType::isValidElementType(elementType)) - return emitError(elementTypeLoc, "invalid tensor element type"), nullptr; - - if (isUnranked) - return UnrankedTensorType::get(elementType); - return RankedTensorType::get(dimensions, elementType); -} - -/// Parse a tuple type. 
-/// -/// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>` -/// -Type Parser::parseTupleType() { - consumeToken(Token::kw_tuple); - - // Parse the '<'. - if (parseToken(Token::less, "expected '<' in tuple type")) - return nullptr; - - // Check for an empty tuple by directly parsing '>'. - if (consumeIf(Token::greater)) - return TupleType::get(getContext()); - - // Parse the element types and the '>'. - SmallVector types; - if (parseTypeListNoParens(types) || - parseToken(Token::greater, "expected '>' in tuple type")) - return nullptr; - - return TupleType::get(types, getContext()); -} - -/// Parse a vector type. -/// -/// vector-type ::= `vector` `<` non-empty-static-dimension-list type `>` -/// non-empty-static-dimension-list ::= decimal-literal `x` -/// static-dimension-list -/// static-dimension-list ::= (decimal-literal `x`)* -/// -VectorType Parser::parseVectorType() { - consumeToken(Token::kw_vector); - - if (parseToken(Token::less, "expected '<' in vector type")) - return nullptr; - - SmallVector dimensions; - if (parseDimensionListRanked(dimensions, /*allowDynamic=*/false)) - return nullptr; - if (dimensions.empty()) - return (emitError("expected dimension size in vector type"), nullptr); - if (any_of(dimensions, [](int64_t i) { return i <= 0; })) - return emitError(getToken().getLoc(), - "vector types must have positive constant sizes"), - nullptr; - - // Parse the element type. - auto typeLoc = getToken().getLoc(); - auto elementType = parseType(); - if (!elementType || parseToken(Token::greater, "expected '>' in vector type")) - return nullptr; - if (!VectorType::isValidElementType(elementType)) - return emitError(typeLoc, "vector elements must be int or float type"), - nullptr; - - return VectorType::get(dimensions, elementType); -} - -/// Parse a dimension list of a tensor or memref type. This populates the -/// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set and -/// errors out on `?` otherwise. 
-/// -/// dimension-list-ranked ::= (dimension `x`)* -/// dimension ::= `?` | decimal-literal -/// -/// When `allowDynamic` is not set, this is used to parse: -/// -/// static-dimension-list ::= (decimal-literal `x`)* -ParseResult -Parser::parseDimensionListRanked(SmallVectorImpl &dimensions, - bool allowDynamic) { - while (getToken().isAny(Token::integer, Token::question)) { - if (consumeIf(Token::question)) { - if (!allowDynamic) - return emitError("expected static shape"); - dimensions.push_back(-1); - } else { - // Hexadecimal integer literals (starting with `0x`) are not allowed in - // aggregate type declarations. Therefore, `0xf32` should be processed as - // a sequence of separate elements `0`, `x`, `f32`. - if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') { - // We can get here only if the token is an integer literal. Hexadecimal - // integer literals can only start with `0x` (`1x` wouldn't lex as a - // literal, just `1` would, at which point we don't get into this - // branch). - assert(getTokenSpelling()[0] == '0' && "invalid integer literal"); - dimensions.push_back(0); - state.lex.resetPointer(getTokenSpelling().data() + 1); - consumeToken(); - } else { - // Make sure this integer value is in bound and valid. - auto dimension = getToken().getUnsignedIntegerValue(); - if (!dimension.hasValue()) - return emitError("invalid dimension"); - dimensions.push_back((int64_t)dimension.getValue()); - consumeToken(Token::integer); - } - } - - // Make sure we have an 'x' or something like 'xbf32'. - if (parseXInDimensionList()) - return failure(); - } - - return success(); -} - -/// Parse an 'x' token in a dimension list, handling the case where the x is -/// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next -/// token. 
-ParseResult Parser::parseXInDimensionList() { - if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x') - return emitError("expected 'x' in dimension list"); - - // If we had a prefix of 'x', lex the next token immediately after the 'x'. - if (getTokenSpelling().size() != 1) - state.lex.resetPointer(getTokenSpelling().data() + 1); - - // Consume the 'x'. - consumeToken(Token::bare_identifier); - - return success(); -} - -// Parse a comma-separated list of dimensions, possibly empty: -// stride-list ::= `[` (dimension (`,` dimension)*)? `]` -ParseResult Parser::parseStrideList(SmallVectorImpl &dimensions) { - if (!consumeIf(Token::l_square)) - return failure(); - // Empty list early exit. - if (consumeIf(Token::r_square)) - return success(); - while (true) { - if (consumeIf(Token::question)) { - dimensions.push_back(MemRefType::getDynamicStrideOrOffset()); - } else { - // This must be an integer value. - int64_t val; - if (getToken().getSpelling().getAsInteger(10, val)) - return emitError("invalid integer value: ") << getToken().getSpelling(); - // Make sure it is not the one value for `?`. - if (ShapedType::isDynamic(val)) - return emitError("invalid integer value: ") - << getToken().getSpelling() - << ", use `?` to specify a dynamic dimension"; - dimensions.push_back(val); - consumeToken(Token::integer); - } - if (!consumeIf(Token::comma)) - break; - } - if (!consumeIf(Token::r_square)) - return failure(); - return success(); -} - -//===----------------------------------------------------------------------===// -// Attribute parsing. -//===----------------------------------------------------------------------===// - -/// Return the symbol reference referred to by the given token, that is known to -/// be an @-identifier. 
-static std::string extractSymbolReference(Token tok) { - assert(tok.is(Token::at_identifier) && "expected valid @-identifier"); - StringRef nameStr = tok.getSpelling().drop_front(); - - // Check to see if the reference is a string literal, or a bare identifier. - if (nameStr.front() == '"') - return tok.getStringValue(); - return std::string(nameStr); -} - -/// Parse an arbitrary attribute. -/// -/// attribute-value ::= `unit` -/// | bool-literal -/// | integer-literal (`:` (index-type | integer-type))? -/// | float-literal (`:` float-type)? -/// | string-literal (`:` type)? -/// | type -/// | `[` (attribute-value (`,` attribute-value)*)? `]` -/// | `{` (attribute-entry (`,` attribute-entry)*)? `}` -/// | symbol-ref-id (`::` symbol-ref-id)* -/// | `dense` `<` attribute-value `>` `:` -/// (tensor-type | vector-type) -/// | `sparse` `<` attribute-value `,` attribute-value `>` -/// `:` (tensor-type | vector-type) -/// | `opaque` `<` dialect-namespace `,` hex-string-literal -/// `>` `:` (tensor-type | vector-type) -/// | extended-attribute -/// -Attribute Parser::parseAttribute(Type type) { - switch (getToken().getKind()) { - // Parse an AffineMap or IntegerSet attribute. - case Token::kw_affine_map: { - consumeToken(Token::kw_affine_map); - - AffineMap map; - if (parseToken(Token::less, "expected '<' in affine map") || - parseAffineMapReference(map) || - parseToken(Token::greater, "expected '>' in affine map")) - return Attribute(); - return AffineMapAttr::get(map); - } - case Token::kw_affine_set: { - consumeToken(Token::kw_affine_set); - - IntegerSet set; - if (parseToken(Token::less, "expected '<' in integer set") || - parseIntegerSetReference(set) || - parseToken(Token::greater, "expected '>' in integer set")) - return Attribute(); - return IntegerSetAttr::get(set); - } - - // Parse an array attribute. 
- case Token::l_square: { - consumeToken(Token::l_square); - - SmallVector elements; - auto parseElt = [&]() -> ParseResult { - elements.push_back(parseAttribute()); - return elements.back() ? success() : failure(); - }; - - if (parseCommaSeparatedListUntil(Token::r_square, parseElt)) - return nullptr; - return builder.getArrayAttr(elements); - } - - // Parse a boolean attribute. - case Token::kw_false: - consumeToken(Token::kw_false); - return builder.getBoolAttr(false); - case Token::kw_true: - consumeToken(Token::kw_true); - return builder.getBoolAttr(true); - - // Parse a dense elements attribute. - case Token::kw_dense: - return parseDenseElementsAttr(type); - - // Parse a dictionary attribute. - case Token::l_brace: { - NamedAttrList elements; - if (parseAttributeDict(elements)) - return nullptr; - return elements.getDictionary(getContext()); - } - - // Parse an extended attribute, i.e. alias or dialect attribute. - case Token::hash_identifier: - return parseExtendedAttr(type); - - // Parse floating point and integer attributes. - case Token::floatliteral: - return parseFloatAttr(type, /*isNegative=*/false); - case Token::integer: - return parseDecOrHexAttr(type, /*isNegative=*/false); - case Token::minus: { - consumeToken(Token::minus); - if (getToken().is(Token::integer)) - return parseDecOrHexAttr(type, /*isNegative=*/true); - if (getToken().is(Token::floatliteral)) - return parseFloatAttr(type, /*isNegative=*/true); - - return (emitError("expected constant integer or floating point value"), - nullptr); - } - - // Parse a location attribute. - case Token::kw_loc: { - LocationAttr attr; - return failed(parseLocation(attr)) ? Attribute() : attr; - } - - // Parse an opaque elements attribute. - case Token::kw_opaque: - return parseOpaqueElementsAttr(type); - - // Parse a sparse elements attribute. - case Token::kw_sparse: - return parseSparseElementsAttr(type); - - // Parse a string attribute. 
- case Token::string: { - auto val = getToken().getStringValue(); - consumeToken(Token::string); - // Parse the optional trailing colon type if one wasn't explicitly provided. - if (!type && consumeIf(Token::colon) && !(type = parseType())) - return Attribute(); - - return type ? StringAttr::get(val, type) - : StringAttr::get(val, getContext()); - } - - // Parse a symbol reference attribute. - case Token::at_identifier: { - std::string nameStr = extractSymbolReference(getToken()); - consumeToken(Token::at_identifier); - - // Parse any nested references. - std::vector nestedRefs; - while (getToken().is(Token::colon)) { - // Check for the '::' prefix. - const char *curPointer = getToken().getLoc().getPointer(); - consumeToken(Token::colon); - if (!consumeIf(Token::colon)) { - state.lex.resetPointer(curPointer); - consumeToken(); - break; - } - // Parse the reference itself. - auto curLoc = getToken().getLoc(); - if (getToken().isNot(Token::at_identifier)) { - emitError(curLoc, "expected nested symbol reference identifier"); - return Attribute(); - } - - std::string nameStr = extractSymbolReference(getToken()); - consumeToken(Token::at_identifier); - nestedRefs.push_back(SymbolRefAttr::get(nameStr, getContext())); - } - - return builder.getSymbolRefAttr(nameStr, nestedRefs); - } - - // Parse a 'unit' attribute. - case Token::kw_unit: - consumeToken(Token::kw_unit); - return builder.getUnitAttr(); - - default: - // Parse a type attribute. - if (Type type = parseType()) - return TypeAttr::get(type); - return nullptr; - } -} - -/// Attribute dictionary. 
-/// -/// attribute-dict ::= `{` `}` -/// | `{` attribute-entry (`,` attribute-entry)* `}` -/// attribute-entry ::= (bare-id | string-literal) `=` attribute-value -/// -ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) { - if (parseToken(Token::l_brace, "expected '{' in attribute dictionary")) - return failure(); - - llvm::SmallDenseSet seenKeys; - auto parseElt = [&]() -> ParseResult { - // The name of an attribute can either be a bare identifier, or a string. - Optional nameId; - if (getToken().is(Token::string)) - nameId = builder.getIdentifier(getToken().getStringValue()); - else if (getToken().isAny(Token::bare_identifier, Token::inttype) || - getToken().isKeyword()) - nameId = builder.getIdentifier(getTokenSpelling()); - else - return emitError("expected attribute name"); - if (!seenKeys.insert(*nameId).second) - return emitError("duplicate key in dictionary attribute"); - consumeToken(); - - // Try to parse the '=' for the attribute value. - if (!consumeIf(Token::equal)) { - // If there is no '=', we treat this as a unit attribute. - attributes.push_back({*nameId, builder.getUnitAttr()}); - return success(); - } - - auto attr = parseAttribute(); - if (!attr) - return failure(); - attributes.push_back({*nameId, attr}); - return success(); - }; - - if (parseCommaSeparatedListUntil(Token::r_brace, parseElt)) - return failure(); - - return success(); -} - -/// Parse an extended attribute. -/// -/// extended-attribute ::= (dialect-attribute | attribute-alias) -/// dialect-attribute ::= `#` dialect-namespace `<` `"` attr-data `"` `>` -/// dialect-attribute ::= `#` alias-name pretty-dialect-sym-body? -/// attribute-alias ::= `#` alias-name -/// -Attribute Parser::parseExtendedAttr(Type type) { - Attribute attr = parseExtendedSymbol( - *this, Token::hash_identifier, state.symbols.attributeAliasDefinitions, - [&](StringRef dialectName, StringRef symbolData, - llvm::SMLoc loc) -> Attribute { - // Parse an optional trailing colon type. 
- Type attrType = type; - if (consumeIf(Token::colon) && !(attrType = parseType())) - return Attribute(); - - // If we found a registered dialect, then ask it to parse the attribute. - if (auto *dialect = state.context->getRegisteredDialect(dialectName)) { - return parseSymbol( - symbolData, state.context, state.symbols, [&](Parser &parser) { - CustomDialectAsmParser customParser(symbolData, parser); - return dialect->parseAttribute(customParser, attrType); - }); - } - - // Otherwise, form a new opaque attribute. - return OpaqueAttr::getChecked( - Identifier::get(dialectName, state.context), symbolData, - attrType ? attrType : NoneType::get(state.context), - getEncodedSourceLocation(loc)); - }); - - // Ensure that the attribute has the same type as requested. - if (attr && type && attr.getType() != type) { - emitError("attribute type different than expected: expected ") - << type << ", but got " << attr.getType(); - return nullptr; - } - return attr; -} - -/// Parse a float attribute. -Attribute Parser::parseFloatAttr(Type type, bool isNegative) { - auto val = getToken().getFloatingPointValue(); - if (!val.hasValue()) - return (emitError("floating point value too large for attribute"), nullptr); - consumeToken(Token::floatliteral); - if (!type) { - // Default to F64 when no type is specified. - if (!consumeIf(Token::colon)) - type = builder.getF64Type(); - else if (!(type = parseType())) - return nullptr; - } - if (!type.isa()) - return (emitError("floating point value not valid for specified type"), - nullptr); - return FloatAttr::get(type, isNegative ? -val.getValue() : val.getValue()); -} - -/// Construct a float attribute bitwise equivalent to the integer literal. -static Optional buildHexadecimalFloatLiteral(Parser *p, FloatType type, - uint64_t value) { - // FIXME: bfloat is currently stored as a double internally because it doesn't - // have valid APFloat semantics. 
- if (type.isF64() || type.isBF16()) - return APFloat(type.getFloatSemantics(), APInt(/*numBits=*/64, value)); - - APInt apInt(type.getWidth(), value); - if (apInt != value) { - p->emitError("hexadecimal float constant out of range for type"); - return llvm::None; - } - return APFloat(type.getFloatSemantics(), apInt); -} - -/// Construct an APint from a parsed value, a known attribute type and -/// sign. -static Optional buildAttributeAPInt(Type type, bool isNegative, - StringRef spelling) { - // Parse the integer value into an APInt that is big enough to hold the value. - APInt result; - bool isHex = spelling.size() > 1 && spelling[1] == 'x'; - if (spelling.getAsInteger(isHex ? 0 : 10, result)) - return llvm::None; - - // Extend or truncate the bitwidth to the right size. - unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth - : type.getIntOrFloatBitWidth(); - if (width > result.getBitWidth()) { - result = result.zext(width); - } else if (width < result.getBitWidth()) { - // The parser can return an unnecessarily wide result with leading zeros. - // This isn't a problem, but truncating off bits is bad. - if (result.countLeadingZeros() < result.getBitWidth() - width) - return llvm::None; - - result = result.trunc(width); - } - - if (isNegative) { - // The value is negative, we have an overflow if the sign bit is not set - // in the negated apInt. - result.negate(); - if (!result.isSignBitSet()) - return llvm::None; - } else if ((type.isSignedInteger() || type.isIndex()) && - result.isSignBitSet()) { - // The value is a positive signed integer or index, - // we have an overflow if the sign bit is set. - return llvm::None; - } - - return result; -} - -/// Parse a decimal or a hexadecimal literal, which can be either an integer -/// or a float attribute. -Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { - // Remember if the literal is hexadecimal. 
- StringRef spelling = getToken().getSpelling(); - auto loc = state.curToken.getLoc(); - bool isHex = spelling.size() > 1 && spelling[1] == 'x'; - - consumeToken(Token::integer); - if (!type) { - // Default to i64 if not type is specified. - if (!consumeIf(Token::colon)) - type = builder.getIntegerType(64); - else if (!(type = parseType())) - return nullptr; - } - - if (auto floatType = type.dyn_cast()) { - if (isNegative) - return emitError( - loc, - "hexadecimal float literal should not have a leading minus"), - nullptr; - if (!isHex) { - emitError(loc, "unexpected decimal integer literal for a float attribute") - .attachNote() - << "add a trailing dot to make the literal a float"; - return nullptr; - } - - auto val = Token::getUInt64IntegerValue(spelling); - if (!val.hasValue()) - return emitError("integer constant out of range for attribute"), nullptr; - - // Construct a float attribute bitwise equivalent to the integer literal. - Optional apVal = - buildHexadecimalFloatLiteral(this, floatType, *val); - return apVal ? FloatAttr::get(floatType, *apVal) : Attribute(); - } - - if (!type.isa() && !type.isa()) - return emitError(loc, "integer literal not valid for specified type"), - nullptr; - - if (isNegative && type.isUnsignedInteger()) { - emitError(loc, - "negative integer literal not valid for unsigned integer type"); - return nullptr; - } - - Optional apInt = buildAttributeAPInt(type, isNegative, spelling); - if (!apInt) - return emitError(loc, "integer constant out of range for attribute"), - nullptr; - return builder.getIntegerAttr(type, *apInt); -} - -/// Parse elements values stored within a hex etring. On success, the values are -/// stored into 'result'. 
-static ParseResult parseElementAttrHexValues(Parser &parser, Token tok, - std::string &result) { - std::string val = tok.getStringValue(); - if (val.size() < 2 || val[0] != '0' || val[1] != 'x') - return parser.emitError(tok.getLoc(), - "elements hex string should start with '0x'"); - - StringRef hexValues = StringRef(val).drop_front(2); - if (!llvm::all_of(hexValues, llvm::isHexDigit)) - return parser.emitError(tok.getLoc(), - "elements hex string only contains hex digits"); - - result = llvm::fromHex(hexValues); - return success(); -} - -/// Parse an opaque elements attribute. -Attribute Parser::parseOpaqueElementsAttr(Type attrType) { - consumeToken(Token::kw_opaque); - if (parseToken(Token::less, "expected '<' after 'opaque'")) - return nullptr; - - if (getToken().isNot(Token::string)) - return (emitError("expected dialect namespace"), nullptr); - - auto name = getToken().getStringValue(); - auto *dialect = builder.getContext()->getRegisteredDialect(name); - // TODO(shpeisman): Allow for having an unknown dialect on an opaque - // attribute. Otherwise, it can't be roundtripped without having the dialect - // registered. - if (!dialect) - return (emitError("no registered dialect with namespace '" + name + "'"), - nullptr); - consumeToken(Token::string); - - if (parseToken(Token::comma, "expected ','")) - return nullptr; - - Token hexTok = getToken(); - if (parseToken(Token::string, "elements hex string should start with '0x'") || - parseToken(Token::greater, "expected '>'")) - return nullptr; - auto type = parseElementsLiteralType(attrType); - if (!type) - return nullptr; - - std::string data; - if (parseElementAttrHexValues(*this, hexTok, data)) - return nullptr; - return OpaqueElementsAttr::get(dialect, type, data); -} - -namespace { -class TensorLiteralParser { -public: - TensorLiteralParser(Parser &p) : p(p) {} - - /// Parse the elements of a tensor literal. 
If 'allowHex' is true, the parser - /// may also parse a tensor literal that is store as a hex string. - ParseResult parse(bool allowHex); - - /// Build a dense attribute instance with the parsed elements and the given - /// shaped type. - DenseElementsAttr getAttr(llvm::SMLoc loc, ShapedType type); - - ArrayRef getShape() const { return shape; } - -private: - /// Get the parsed elements for an integer attribute. - ParseResult getIntAttrElements(llvm::SMLoc loc, Type eltTy, - std::vector &intValues); - - /// Get the parsed elements for a float attribute. - ParseResult getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy, - std::vector &floatValues); - - /// Build a Dense String attribute for the given type. - DenseElementsAttr getStringAttr(llvm::SMLoc loc, ShapedType type, Type eltTy); - - /// Build a Dense attribute with hex data for the given type. - DenseElementsAttr getHexAttr(llvm::SMLoc loc, ShapedType type); - - /// Parse a single element, returning failure if it isn't a valid element - /// literal. For example: - /// parseElement(1) -> Success, 1 - /// parseElement([1]) -> Failure - ParseResult parseElement(); - - /// Parse a list of either lists or elements, returning the dimensions of the - /// parsed sub-tensors in dims. For example: - /// parseList([1, 2, 3]) -> Success, [3] - /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] - /// parseList([[1, 2], 3]) -> Failure - /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure - ParseResult parseList(SmallVectorImpl &dims); - - /// Parse a literal that was printed as a hex string. - ParseResult parseHexElements(); - - Parser &p; - - /// The shape inferred from the parsed elements. - SmallVector shape; - - /// Storage used when parsing elements, this is a pair of . - std::vector> storage; - - /// Storage used when parsing elements that were stored as hex values. - Optional hexStorage; -}; -} // namespace - -/// Parse the elements of a tensor literal. 
If 'allowHex' is true, the parser -/// may also parse a tensor literal that is store as a hex string. -ParseResult TensorLiteralParser::parse(bool allowHex) { - // If hex is allowed, check for a string literal. - if (allowHex && p.getToken().is(Token::string)) { - hexStorage = p.getToken(); - p.consumeToken(Token::string); - return success(); - } - // Otherwise, parse a list or an individual element. - if (p.getToken().is(Token::l_square)) - return parseList(shape); - return parseElement(); -} - -/// Build a dense attribute instance with the parsed elements and the given -/// shaped type. -DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc, - ShapedType type) { - Type eltType = type.getElementType(); - - // Check to see if we parse the literal from a hex string. - if (hexStorage.hasValue() && - (eltType.isIntOrFloat() || eltType.isa())) - return getHexAttr(loc, type); - - // Check that the parsed storage size has the same number of elements to the - // type, or is a known splat. - if (!shape.empty() && getShape() != type.getShape()) { - p.emitError(loc) << "inferred shape of elements literal ([" << getShape() - << "]) does not match type ([" << type.getShape() << "])"; - return nullptr; - } - - // Handle complex types in the specific element type cases below. - bool isComplex = false; - if (ComplexType complexTy = eltType.dyn_cast()) { - eltType = complexTy.getElementType(); - isComplex = true; - } - - // Handle integer and index types. - if (eltType.isIntOrIndex()) { - std::vector intValues; - if (failed(getIntAttrElements(loc, eltType, intValues))) - return nullptr; - if (isComplex) { - // If this is a complex, treat the parsed values as complex values. - auto complexData = llvm::makeArrayRef( - reinterpret_cast *>(intValues.data()), - intValues.size() / 2); - return DenseElementsAttr::get(type, complexData); - } - return DenseElementsAttr::get(type, intValues); - } - // Handle floating point types. 
- if (FloatType floatTy = eltType.dyn_cast()) { - std::vector floatValues; - if (failed(getFloatAttrElements(loc, floatTy, floatValues))) - return nullptr; - if (isComplex) { - // If this is a complex, treat the parsed values as complex values. - auto complexData = llvm::makeArrayRef( - reinterpret_cast *>(floatValues.data()), - floatValues.size() / 2); - return DenseElementsAttr::get(type, complexData); - } - return DenseElementsAttr::get(type, floatValues); - } - - // Other types are assumed to be string representations. - return getStringAttr(loc, type, type.getElementType()); -} - -/// Build a Dense Integer attribute for the given type. -ParseResult -TensorLiteralParser::getIntAttrElements(llvm::SMLoc loc, Type eltTy, - std::vector &intValues) { - intValues.reserve(storage.size()); - bool isUintType = eltTy.isUnsignedInteger(); - for (const auto &signAndToken : storage) { - bool isNegative = signAndToken.first; - const Token &token = signAndToken.second; - auto tokenLoc = token.getLoc(); - - if (isNegative && isUintType) { - return p.emitError(tokenLoc) - << "expected unsigned integer elements, but parsed negative value"; - } - - // Check to see if floating point values were parsed. - if (token.is(Token::floatliteral)) { - return p.emitError(tokenLoc) - << "expected integer elements, but parsed floating-point"; - } - - assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) && - "unexpected token type"); - if (token.isAny(Token::kw_true, Token::kw_false)) { - if (!eltTy.isInteger(1)) { - return p.emitError(tokenLoc) - << "expected i1 type for 'true' or 'false' values"; - } - APInt apInt(1, token.is(Token::kw_true), /*isSigned=*/false); - intValues.push_back(apInt); - continue; - } - - // Create APInt values for each element with the correct bitwidth. 
- Optional apInt = - buildAttributeAPInt(eltTy, isNegative, token.getSpelling()); - if (!apInt) - return p.emitError(tokenLoc, "integer constant out of range for type"); - intValues.push_back(*apInt); - } - return success(); -} - -/// Build a Dense Float attribute for the given type. -ParseResult -TensorLiteralParser::getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy, - std::vector &floatValues) { - floatValues.reserve(storage.size()); - for (const auto &signAndToken : storage) { - bool isNegative = signAndToken.first; - const Token &token = signAndToken.second; - - // Handle hexadecimal float literals. - if (token.is(Token::integer) && token.getSpelling().startswith("0x")) { - if (isNegative) { - return p.emitError(token.getLoc()) - << "hexadecimal float literal should not have a leading minus"; - } - auto val = token.getUInt64IntegerValue(); - if (!val.hasValue()) { - return p.emitError( - "hexadecimal float constant out of range for attribute"); - } - Optional apVal = buildHexadecimalFloatLiteral(&p, eltTy, *val); - if (!apVal) - return failure(); - floatValues.push_back(*apVal); - continue; - } - - // Check to see if any decimal integers or booleans were parsed. - if (!token.is(Token::floatliteral)) - return p.emitError() - << "expected floating-point elements, but parsed integer"; - - // Build the float values from tokens. - auto val = token.getFloatingPointValue(); - if (!val.hasValue()) - return p.emitError("floating point value too large for attribute"); - - // Treat BF16 as double because it is not supported in LLVM's APFloat. - APFloat apVal(isNegative ? -*val : *val); - if (!eltTy.isBF16() && !eltTy.isF64()) { - bool unused; - apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven, - &unused); - } - floatValues.push_back(apVal); - } - return success(); -} - -/// Build a Dense String attribute for the given type. 
-DenseElementsAttr TensorLiteralParser::getStringAttr(llvm::SMLoc loc, - ShapedType type, - Type eltTy) { - if (hexStorage.hasValue()) { - auto stringValue = hexStorage.getValue().getStringValue(); - return DenseStringElementsAttr::get(type, {stringValue}); - } - - std::vector stringValues; - std::vector stringRefValues; - stringValues.reserve(storage.size()); - stringRefValues.reserve(storage.size()); - - for (auto val : storage) { - stringValues.push_back(val.second.getStringValue()); - stringRefValues.push_back(stringValues.back()); - } - - return DenseStringElementsAttr::get(type, stringRefValues); -} - -/// Build a Dense attribute with hex data for the given type. -DenseElementsAttr TensorLiteralParser::getHexAttr(llvm::SMLoc loc, - ShapedType type) { - Type elementType = type.getElementType(); - if (!elementType.isIntOrIndexOrFloat() && !elementType.isa()) { - p.emitError(loc) - << "expected floating-point, integer, or complex element type, got " - << elementType; - return nullptr; - } - - std::string data; - if (parseElementAttrHexValues(p, hexStorage.getValue(), data)) - return nullptr; - - ArrayRef rawData(data.data(), data.size()); - bool detectedSplat = false; - if (!DenseElementsAttr::isValidRawBuffer(type, rawData, detectedSplat)) { - p.emitError(loc) << "elements hex data size is invalid for provided type: " - << type; - return nullptr; - } - - return DenseElementsAttr::getFromRawBuffer(type, rawData, detectedSplat); -} - -ParseResult TensorLiteralParser::parseElement() { - switch (p.getToken().getKind()) { - // Parse a boolean element. - case Token::kw_true: - case Token::kw_false: - case Token::floatliteral: - case Token::integer: - storage.emplace_back(/*isNegative=*/false, p.getToken()); - p.consumeToken(); - break; - - // Parse a signed integer or a negative floating-point element. 
- case Token::minus: - p.consumeToken(Token::minus); - if (!p.getToken().isAny(Token::floatliteral, Token::integer)) - return p.emitError("expected integer or floating point literal"); - storage.emplace_back(/*isNegative=*/true, p.getToken()); - p.consumeToken(); - break; - - case Token::string: - storage.emplace_back(/*isNegative=*/ false, p.getToken()); - p.consumeToken(); - break; - - // Parse a complex element of the form '(' element ',' element ')'. - case Token::l_paren: - p.consumeToken(Token::l_paren); - if (parseElement() || - p.parseToken(Token::comma, "expected ',' between complex elements") || - parseElement() || - p.parseToken(Token::r_paren, "expected ')' after complex elements")) - return failure(); - break; - - default: - return p.emitError("expected element literal of primitive type"); - } - - return success(); -} - -/// Parse a list of either lists or elements, returning the dimensions of the -/// parsed sub-tensors in dims. For example: -/// parseList([1, 2, 3]) -> Success, [3] -/// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] -/// parseList([[1, 2], 3]) -> Failure -/// parseList([[1, [2, 3]], [4, [5]]]) -> Failure -ParseResult TensorLiteralParser::parseList(SmallVectorImpl &dims) { - p.consumeToken(Token::l_square); - - auto checkDims = [&](const SmallVectorImpl &prevDims, - const SmallVectorImpl &newDims) -> ParseResult { - if (prevDims == newDims) - return success(); - return p.emitError("tensor literal is invalid; ranks are not consistent " - "between elements"); - }; - - bool first = true; - SmallVector newDims; - unsigned size = 0; - auto parseCommaSeparatedList = [&]() -> ParseResult { - SmallVector thisDims; - if (p.getToken().getKind() == Token::l_square) { - if (parseList(thisDims)) - return failure(); - } else if (parseElement()) { - return failure(); - } - ++size; - if (!first) - return checkDims(newDims, thisDims); - newDims = thisDims; - first = false; - return success(); - }; - if (p.parseCommaSeparatedListUntil(Token::r_square, 
parseCommaSeparatedList)) - return failure(); - - // Return the sublists' dimensions with 'size' prepended. - dims.clear(); - dims.push_back(size); - dims.append(newDims.begin(), newDims.end()); - return success(); -} - -/// Parse a dense elements attribute. -Attribute Parser::parseDenseElementsAttr(Type attrType) { - consumeToken(Token::kw_dense); - if (parseToken(Token::less, "expected '<' after 'dense'")) - return nullptr; - - // Parse the literal data. - TensorLiteralParser literalParser(*this); - if (literalParser.parse(/*allowHex=*/true)) - return nullptr; - - if (parseToken(Token::greater, "expected '>'")) - return nullptr; - - auto typeLoc = getToken().getLoc(); - auto type = parseElementsLiteralType(attrType); - if (!type) - return nullptr; - return literalParser.getAttr(typeLoc, type); -} - -/// Shaped type for elements attribute. -/// -/// elements-literal-type ::= vector-type | ranked-tensor-type -/// -/// This method also checks the type has static shape. -ShapedType Parser::parseElementsLiteralType(Type type) { - // If the user didn't provide a type, parse the colon type for the literal. - if (!type) { - if (parseToken(Token::colon, "expected ':'")) - return nullptr; - if (!(type = parseType())) - return nullptr; - } - - if (!type.isa() && !type.isa()) { - emitError("elements literal must be a ranked tensor or vector type"); - return nullptr; - } - - auto sType = type.cast(); - if (!sType.hasStaticShape()) - return (emitError("elements literal type must have static shape"), nullptr); - - return sType; -} - -/// Parse a sparse elements attribute. -Attribute Parser::parseSparseElementsAttr(Type attrType) { - consumeToken(Token::kw_sparse); - if (parseToken(Token::less, "Expected '<' after 'sparse'")) - return nullptr; - - /// Parse the indices. We don't allow hex values here as we may need to use - /// the inferred shape. 
- auto indicesLoc = getToken().getLoc(); - TensorLiteralParser indiceParser(*this); - if (indiceParser.parse(/*allowHex=*/false)) - return nullptr; - - if (parseToken(Token::comma, "expected ','")) - return nullptr; - - /// Parse the values. - auto valuesLoc = getToken().getLoc(); - TensorLiteralParser valuesParser(*this); - if (valuesParser.parse(/*allowHex=*/true)) - return nullptr; - - if (parseToken(Token::greater, "expected '>'")) - return nullptr; - - auto type = parseElementsLiteralType(attrType); - if (!type) - return nullptr; - - // If the indices are a splat, i.e. the literal parser parsed an element and - // not a list, we set the shape explicitly. The indices are represented by a - // 2-dimensional shape where the second dimension is the rank of the type. - // Given that the parsed indices is a splat, we know that we only have one - // indice and thus one for the first dimension. - auto indiceEltType = builder.getIntegerType(64); - ShapedType indicesType; - if (indiceParser.getShape().empty()) { - indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType); - } else { - // Otherwise, set the shape to the one parsed by the literal parser. - indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType); - } - auto indices = indiceParser.getAttr(indicesLoc, indicesType); - - // If the values are a splat, set the shape explicitly based on the number of - // indices. The number of indices is encoded in the first dimension of the - // indice shape type. - auto valuesEltType = type.getElementType(); - ShapedType valuesType = - valuesParser.getShape().empty() - ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType) - : RankedTensorType::get(valuesParser.getShape(), valuesEltType); - auto values = valuesParser.getAttr(valuesLoc, valuesType); - - /// Sanity check. 
- if (valuesType.getRank() != 1) - return (emitError("expected 1-d tensor for values"), nullptr); - - auto sameShape = (indicesType.getRank() == 1) || - (type.getRank() == indicesType.getDimSize(1)); - auto sameElementNum = indicesType.getDimSize(0) == valuesType.getDimSize(0); - if (!sameShape || !sameElementNum) { - emitError() << "expected shape ([" << type.getShape() - << "]); inferred shape of indices literal ([" - << indicesType.getShape() - << "]); inferred shape of values literal ([" - << valuesType.getShape() << "])"; - return nullptr; - } - - // Build the sparse elements attribute by the indices and values. - return SparseElementsAttr::get(type, indices, values); -} - -//===----------------------------------------------------------------------===// -// Location parsing. -//===----------------------------------------------------------------------===// - -/// Parse a location. -/// -/// location ::= `loc` inline-location -/// inline-location ::= '(' location-inst ')' -/// -ParseResult Parser::parseLocation(LocationAttr &loc) { - // Check for 'loc' identifier. - if (parseToken(Token::kw_loc, "expected 'loc' keyword")) - return emitError(); - - // Parse the inline-location. - if (parseToken(Token::l_paren, "expected '(' in inline location") || - parseLocationInstance(loc) || - parseToken(Token::r_paren, "expected ')' in inline location")) - return failure(); - return success(); -} - -/// Specific location instances. -/// -/// location-inst ::= filelinecol-location | -/// name-location | -/// callsite-location | -/// fused-location | -/// unknown-location -/// filelinecol-location ::= string-literal ':' integer-literal -/// ':' integer-literal -/// name-location ::= string-literal -/// callsite-location ::= 'callsite' '(' location-inst 'at' location-inst ')' -/// fused-location ::= fused ('<' attribute-value '>')? 
-/// '[' location-inst (location-inst ',')* ']' -/// unknown-location ::= 'unknown' -/// -ParseResult Parser::parseCallSiteLocation(LocationAttr &loc) { - consumeToken(Token::bare_identifier); - - // Parse the '('. - if (parseToken(Token::l_paren, "expected '(' in callsite location")) - return failure(); - - // Parse the callee location. - LocationAttr calleeLoc; - if (parseLocationInstance(calleeLoc)) - return failure(); - - // Parse the 'at'. - if (getToken().isNot(Token::bare_identifier) || - getToken().getSpelling() != "at") - return emitError("expected 'at' in callsite location"); - consumeToken(Token::bare_identifier); - - // Parse the caller location. - LocationAttr callerLoc; - if (parseLocationInstance(callerLoc)) - return failure(); - - // Parse the ')'. - if (parseToken(Token::r_paren, "expected ')' in callsite location")) - return failure(); - - // Return the callsite location. - loc = CallSiteLoc::get(calleeLoc, callerLoc); - return success(); -} - -ParseResult Parser::parseFusedLocation(LocationAttr &loc) { - consumeToken(Token::bare_identifier); - - // Try to parse the optional metadata. - Attribute metadata; - if (consumeIf(Token::less)) { - metadata = parseAttribute(); - if (!metadata) - return emitError("expected valid attribute metadata"); - // Parse the '>' token. - if (parseToken(Token::greater, - "expected '>' after fused location metadata")) - return failure(); - } - - SmallVector locations; - auto parseElt = [&] { - LocationAttr newLoc; - if (parseLocationInstance(newLoc)) - return failure(); - locations.push_back(newLoc); - return success(); - }; - - if (parseToken(Token::l_square, "expected '[' in fused location") || - parseCommaSeparatedList(parseElt) || - parseToken(Token::r_square, "expected ']' in fused location")) - return failure(); - - // Return the fused location. 
- loc = FusedLoc::get(locations, metadata, getContext()); - return success(); -} - -ParseResult Parser::parseNameOrFileLineColLocation(LocationAttr &loc) { - auto *ctx = getContext(); - auto str = getToken().getStringValue(); - consumeToken(Token::string); - - // If the next token is ':' this is a filelinecol location. - if (consumeIf(Token::colon)) { - // Parse the line number. - if (getToken().isNot(Token::integer)) - return emitError("expected integer line number in FileLineColLoc"); - auto line = getToken().getUnsignedIntegerValue(); - if (!line.hasValue()) - return emitError("expected integer line number in FileLineColLoc"); - consumeToken(Token::integer); - - // Parse the ':'. - if (parseToken(Token::colon, "expected ':' in FileLineColLoc")) - return failure(); - - // Parse the column number. - if (getToken().isNot(Token::integer)) - return emitError("expected integer column number in FileLineColLoc"); - auto column = getToken().getUnsignedIntegerValue(); - if (!column.hasValue()) - return emitError("expected integer column number in FileLineColLoc"); - consumeToken(Token::integer); - - loc = FileLineColLoc::get(str, line.getValue(), column.getValue(), ctx); - return success(); - } - - // Otherwise, this is a NameLoc. - - // Check for a child location. - if (consumeIf(Token::l_paren)) { - auto childSourceLoc = getToken().getLoc(); - - // Parse the child location. - LocationAttr childLoc; - if (parseLocationInstance(childLoc)) - return failure(); - - // The child must not be another NameLoc. - if (childLoc.isa()) - return emitError(childSourceLoc, - "child of NameLoc cannot be another NameLoc"); - loc = NameLoc::get(Identifier::get(str, ctx), childLoc); - - // Parse the closing ')'. 
- if (parseToken(Token::r_paren, - "expected ')' after child location of NameLoc")) - return failure(); - } else { - loc = NameLoc::get(Identifier::get(str, ctx), ctx); - } - - return success(); -} - -ParseResult Parser::parseLocationInstance(LocationAttr &loc) { - // Handle either name or filelinecol locations. - if (getToken().is(Token::string)) - return parseNameOrFileLineColLocation(loc); - - // Bare tokens required for other cases. - if (!getToken().is(Token::bare_identifier)) - return emitError("expected location instance"); - - // Check for the 'callsite' signifying a callsite location. - if (getToken().getSpelling() == "callsite") - return parseCallSiteLocation(loc); - - // If the token is 'fused', then this is a fused location. - if (getToken().getSpelling() == "fused") - return parseFusedLocation(loc); - - // Check for a 'unknown' for an unknown location. - if (getToken().getSpelling() == "unknown") { - consumeToken(Token::bare_identifier); - loc = UnknownLoc::get(getContext()); - return success(); - } +#include "Parser.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Parser.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/ADT/bit.h" +#include "llvm/Support/PrettyStackTrace.h" +#include "llvm/Support/SourceMgr.h" +#include - return emitError("expected location instance"); -} +using namespace mlir; +using namespace mlir::detail; +using llvm::MemoryBuffer; +using llvm::SMLoc; +using llvm::SourceMgr; //===----------------------------------------------------------------------===// -// Affine parsing. +// Parser //===----------------------------------------------------------------------===// -/// Lower precedence ops (all at the same precedence level). LNoOp is false in -/// the boolean sense. -enum AffineLowPrecOp { - /// Null value. - LNoOp, - Add, - Sub -}; - -/// Higher precedence ops - all at the same precedence level. HNoOp is false -/// in the boolean sense. 
-enum AffineHighPrecOp { - /// Null value. - HNoOp, - Mul, - FloorDiv, - CeilDiv, - Mod -}; - -namespace { -/// This is a specialized parser for affine structures (affine maps, affine -/// expressions, and integer sets), maintaining the state transient to their -/// bodies. -class AffineParser : public Parser { -public: - AffineParser(ParserState &state, bool allowParsingSSAIds = false, - function_ref parseElement = nullptr) - : Parser(state), allowParsingSSAIds(allowParsingSSAIds), - parseElement(parseElement), numDimOperands(0), numSymbolOperands(0) {} - - AffineMap parseAffineMapRange(unsigned numDims, unsigned numSymbols); - ParseResult parseAffineMapOrIntegerSetInline(AffineMap &map, IntegerSet &set); - IntegerSet parseIntegerSetConstraints(unsigned numDims, unsigned numSymbols); - ParseResult parseAffineMapOfSSAIds(AffineMap &map, - OpAsmParser::Delimiter delimiter); - void getDimsAndSymbolSSAIds(SmallVectorImpl &dimAndSymbolSSAIds, - unsigned &numDims); - -private: - // Binary affine op parsing. - AffineLowPrecOp consumeIfLowPrecOp(); - AffineHighPrecOp consumeIfHighPrecOp(); - - // Identifier lists for polyhedral structures. 
- ParseResult parseDimIdList(unsigned &numDims); - ParseResult parseSymbolIdList(unsigned &numSymbols); - ParseResult parseDimAndOptionalSymbolIdList(unsigned &numDims, - unsigned &numSymbols); - ParseResult parseIdentifierDefinition(AffineExpr idExpr); - - AffineExpr parseAffineExpr(); - AffineExpr parseParentheticalExpr(); - AffineExpr parseNegateExpression(AffineExpr lhs); - AffineExpr parseIntegerExpr(); - AffineExpr parseBareIdExpr(); - AffineExpr parseSSAIdExpr(bool isSymbol); - AffineExpr parseSymbolSSAIdExpr(); - - AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr lhs, - AffineExpr rhs, SMLoc opLoc); - AffineExpr getAffineBinaryOpExpr(AffineLowPrecOp op, AffineExpr lhs, - AffineExpr rhs); - AffineExpr parseAffineOperandExpr(AffineExpr lhs); - AffineExpr parseAffineLowPrecOpExpr(AffineExpr llhs, AffineLowPrecOp llhsOp); - AffineExpr parseAffineHighPrecOpExpr(AffineExpr llhs, AffineHighPrecOp llhsOp, - SMLoc llhsOpLoc); - AffineExpr parseAffineConstraint(bool *isEq); - -private: - bool allowParsingSSAIds; - function_ref parseElement; - unsigned numDimOperands; - unsigned numSymbolOperands; - SmallVector, 4> dimsAndSymbols; -}; -} // end anonymous namespace - -/// Create an affine binary high precedence op expression (mul's, div's, mod). -/// opLoc is the location of the op token to be used to report errors -/// for non-conforming expressions. -AffineExpr AffineParser::getAffineBinaryOpExpr(AffineHighPrecOp op, - AffineExpr lhs, AffineExpr rhs, - SMLoc opLoc) { - // TODO: make the error location info accurate. 
- switch (op) { - case Mul: - if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant()) { - emitError(opLoc, "non-affine expression: at least one of the multiply " - "operands has to be either a constant or symbolic"); - return nullptr; - } - return lhs * rhs; - case FloorDiv: - if (!rhs.isSymbolicOrConstant()) { - emitError(opLoc, "non-affine expression: right operand of floordiv " - "has to be either a constant or symbolic"); - return nullptr; - } - return lhs.floorDiv(rhs); - case CeilDiv: - if (!rhs.isSymbolicOrConstant()) { - emitError(opLoc, "non-affine expression: right operand of ceildiv " - "has to be either a constant or symbolic"); - return nullptr; - } - return lhs.ceilDiv(rhs); - case Mod: - if (!rhs.isSymbolicOrConstant()) { - emitError(opLoc, "non-affine expression: right operand of mod " - "has to be either a constant or symbolic"); - return nullptr; - } - return lhs % rhs; - case HNoOp: - llvm_unreachable("can't create affine expression for null high prec op"); - return nullptr; - } - llvm_unreachable("Unknown AffineHighPrecOp"); -} - -/// Create an affine binary low precedence op expression (add, sub). -AffineExpr AffineParser::getAffineBinaryOpExpr(AffineLowPrecOp op, - AffineExpr lhs, AffineExpr rhs) { - switch (op) { - case AffineLowPrecOp::Add: - return lhs + rhs; - case AffineLowPrecOp::Sub: - return lhs - rhs; - case AffineLowPrecOp::LNoOp: - llvm_unreachable("can't create affine expression for null low prec op"); - return nullptr; - } - llvm_unreachable("Unknown AffineLowPrecOp"); -} - -/// Consume this token if it is a lower precedence affine op (there are only -/// two precedence levels). 
-AffineLowPrecOp AffineParser::consumeIfLowPrecOp() { - switch (getToken().getKind()) { - case Token::plus: - consumeToken(Token::plus); - return AffineLowPrecOp::Add; - case Token::minus: - consumeToken(Token::minus); - return AffineLowPrecOp::Sub; - default: - return AffineLowPrecOp::LNoOp; - } -} - -/// Consume this token if it is a higher precedence affine op (there are only -/// two precedence levels) -AffineHighPrecOp AffineParser::consumeIfHighPrecOp() { - switch (getToken().getKind()) { - case Token::star: - consumeToken(Token::star); - return Mul; - case Token::kw_floordiv: - consumeToken(Token::kw_floordiv); - return FloorDiv; - case Token::kw_ceildiv: - consumeToken(Token::kw_ceildiv); - return CeilDiv; - case Token::kw_mod: - consumeToken(Token::kw_mod); - return Mod; - default: - return HNoOp; - } -} - -/// Parse a high precedence op expression list: mul, div, and mod are high -/// precedence binary ops, i.e., parse a -/// expr_1 op_1 expr_2 op_2 ... expr_n -/// where op_1, op_2 are all a AffineHighPrecOp (mul, div, mod). -/// All affine binary ops are left associative. -/// Given llhs, returns (llhs llhsOp lhs) op rhs, or (lhs op rhs) if llhs is -/// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is -/// null. llhsOpLoc is the location of the llhsOp token that will be used to -/// report an error for non-conforming expressions. -AffineExpr AffineParser::parseAffineHighPrecOpExpr(AffineExpr llhs, - AffineHighPrecOp llhsOp, - SMLoc llhsOpLoc) { - AffineExpr lhs = parseAffineOperandExpr(llhs); - if (!lhs) - return nullptr; - - // Found an LHS. Parse the remaining expression. 
- auto opLoc = getToken().getLoc(); - if (AffineHighPrecOp op = consumeIfHighPrecOp()) { - if (llhs) { - AffineExpr expr = getAffineBinaryOpExpr(llhsOp, llhs, lhs, opLoc); - if (!expr) - return nullptr; - return parseAffineHighPrecOpExpr(expr, op, opLoc); - } - // No LLHS, get RHS - return parseAffineHighPrecOpExpr(lhs, op, opLoc); - } - - // This is the last operand in this expression. - if (llhs) - return getAffineBinaryOpExpr(llhsOp, llhs, lhs, llhsOpLoc); - - // No llhs, 'lhs' itself is the expression. - return lhs; -} - -/// Parse an affine expression inside parentheses. -/// -/// affine-expr ::= `(` affine-expr `)` -AffineExpr AffineParser::parseParentheticalExpr() { - if (parseToken(Token::l_paren, "expected '('")) - return nullptr; - if (getToken().is(Token::r_paren)) - return (emitError("no expression inside parentheses"), nullptr); - - auto expr = parseAffineExpr(); - if (!expr) - return nullptr; - if (parseToken(Token::r_paren, "expected ')'")) - return nullptr; - - return expr; -} - -/// Parse the negation expression. -/// -/// affine-expr ::= `-` affine-expr -AffineExpr AffineParser::parseNegateExpression(AffineExpr lhs) { - if (parseToken(Token::minus, "expected '-'")) - return nullptr; - - AffineExpr operand = parseAffineOperandExpr(lhs); - // Since negation has the highest precedence of all ops (including high - // precedence ops) but lower than parentheses, we are only going to use - // parseAffineOperandExpr instead of parseAffineExpr here. - if (!operand) - // Extra error message although parseAffineOperandExpr would have - // complained. Leads to a better diagnostic. - return (emitError("missing operand of negation"), nullptr); - return (-1) * operand; -} - -/// Parse a bare id that may appear in an affine expression. 
-/// -/// affine-expr ::= bare-id -AffineExpr AffineParser::parseBareIdExpr() { - if (getToken().isNot(Token::bare_identifier)) - return (emitError("expected bare identifier"), nullptr); - - StringRef sRef = getTokenSpelling(); - for (auto entry : dimsAndSymbols) { - if (entry.first == sRef) { - consumeToken(Token::bare_identifier); - return entry.second; - } - } - - return (emitError("use of undeclared identifier"), nullptr); -} - -/// Parse an SSA id which may appear in an affine expression. -AffineExpr AffineParser::parseSSAIdExpr(bool isSymbol) { - if (!allowParsingSSAIds) - return (emitError("unexpected ssa identifier"), nullptr); - if (getToken().isNot(Token::percent_identifier)) - return (emitError("expected ssa identifier"), nullptr); - auto name = getTokenSpelling(); - // Check if we already parsed this SSA id. - for (auto entry : dimsAndSymbols) { - if (entry.first == name) { - consumeToken(Token::percent_identifier); - return entry.second; - } - } - // Parse the SSA id and add an AffineDim/SymbolExpr to represent it. - if (parseElement(isSymbol)) - return (emitError("failed to parse ssa identifier"), nullptr); - auto idExpr = isSymbol - ? getAffineSymbolExpr(numSymbolOperands++, getContext()) - : getAffineDimExpr(numDimOperands++, getContext()); - dimsAndSymbols.push_back({name, idExpr}); - return idExpr; -} - -AffineExpr AffineParser::parseSymbolSSAIdExpr() { - if (parseToken(Token::kw_symbol, "expected symbol keyword") || - parseToken(Token::l_paren, "expected '(' at start of SSA symbol")) - return nullptr; - AffineExpr symbolExpr = parseSSAIdExpr(/*isSymbol=*/true); - if (!symbolExpr) - return nullptr; - if (parseToken(Token::r_paren, "expected ')' at end of SSA symbol")) - return nullptr; - return symbolExpr; -} - -/// Parse a positive integral constant appearing in an affine expression. 
-/// -/// affine-expr ::= integer-literal -AffineExpr AffineParser::parseIntegerExpr() { - auto val = getToken().getUInt64IntegerValue(); - if (!val.hasValue() || (int64_t)val.getValue() < 0) - return (emitError("constant too large for index"), nullptr); - - consumeToken(Token::integer); - return builder.getAffineConstantExpr((int64_t)val.getValue()); -} - -/// Parses an expression that can be a valid operand of an affine expression. -/// lhs: if non-null, lhs is an affine expression that is the lhs of a binary -/// operator, the rhs of which is being parsed. This is used to determine -/// whether an error should be emitted for a missing right operand. -// Eg: for an expression without parentheses (like i + j + k + l), each -// of the four identifiers is an operand. For i + j*k + l, j*k is not an -// operand expression, it's an op expression and will be parsed via -// parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and -// -l are valid operands that will be parsed by this function. 
-AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) { - switch (getToken().getKind()) { - case Token::bare_identifier: - return parseBareIdExpr(); - case Token::kw_symbol: - return parseSymbolSSAIdExpr(); - case Token::percent_identifier: - return parseSSAIdExpr(/*isSymbol=*/false); - case Token::integer: - return parseIntegerExpr(); - case Token::l_paren: - return parseParentheticalExpr(); - case Token::minus: - return parseNegateExpression(lhs); - case Token::kw_ceildiv: - case Token::kw_floordiv: - case Token::kw_mod: - case Token::plus: - case Token::star: - if (lhs) - emitError("missing right operand of binary operator"); - else - emitError("missing left operand of binary operator"); - return nullptr; - default: - if (lhs) - emitError("missing right operand of binary operator"); - else - emitError("expected affine expression"); - return nullptr; - } -} - -/// Parse affine expressions that are bare-id's, integer constants, -/// parenthetical affine expressions, and affine op expressions that are a -/// composition of those. -/// -/// All binary op's associate from left to right. -/// -/// {add, sub} have lower precedence than {mul, div, and mod}. -/// -/// Add, sub'are themselves at the same precedence level. Mul, floordiv, -/// ceildiv, and mod are at the same higher precedence level. Negation has -/// higher precedence than any binary op. -/// -/// llhs: the affine expression appearing on the left of the one being parsed. -/// This function will return ((llhs llhsOp lhs) op rhs) if llhs is non null, -/// and lhs op rhs otherwise; if there is no rhs, llhs llhsOp lhs is returned -/// if llhs is non-null; otherwise lhs is returned. This is to deal with left -/// associativity. -/// -/// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function -/// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where -/// (e2*e3) will be parsed using parseAffineHighPrecOpExpr(). 
-AffineExpr AffineParser::parseAffineLowPrecOpExpr(AffineExpr llhs, - AffineLowPrecOp llhsOp) { - AffineExpr lhs; - if (!(lhs = parseAffineOperandExpr(llhs))) - return nullptr; - - // Found an LHS. Deal with the ops. - if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) { - if (llhs) { - AffineExpr sum = getAffineBinaryOpExpr(llhsOp, llhs, lhs); - return parseAffineLowPrecOpExpr(sum, lOp); - } - // No LLHS, get RHS and form the expression. - return parseAffineLowPrecOpExpr(lhs, lOp); - } - auto opLoc = getToken().getLoc(); - if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) { - // We have a higher precedence op here. Get the rhs operand for the llhs - // through parseAffineHighPrecOpExpr. - AffineExpr highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc); - if (!highRes) - return nullptr; - - // If llhs is null, the product forms the first operand of the yet to be - // found expression. If non-null, the op to associate with llhs is llhsOp. - AffineExpr expr = - llhs ? getAffineBinaryOpExpr(llhsOp, llhs, highRes) : highRes; - - // Recurse for subsequent low prec op's after the affine high prec op - // expression. - if (AffineLowPrecOp nextOp = consumeIfLowPrecOp()) - return parseAffineLowPrecOpExpr(expr, nextOp); - return expr; - } - // Last operand in the expression list. - if (llhs) - return getAffineBinaryOpExpr(llhsOp, llhs, lhs); - // No llhs, 'lhs' itself is the expression. - return lhs; -} - -/// Parse an affine expression. -/// affine-expr ::= `(` affine-expr `)` -/// | `-` affine-expr -/// | affine-expr `+` affine-expr -/// | affine-expr `-` affine-expr -/// | affine-expr `*` affine-expr -/// | affine-expr `floordiv` affine-expr -/// | affine-expr `ceildiv` affine-expr -/// | affine-expr `mod` affine-expr -/// | bare-id -/// | integer-literal -/// -/// Additional conditions are checked depending on the production. 
For eg., -/// one of the operands for `*` has to be either constant/symbolic; the second -/// operand for floordiv, ceildiv, and mod has to be a positive integer. -AffineExpr AffineParser::parseAffineExpr() { - return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp); -} - -/// Parse a dim or symbol from the lists appearing before the actual -/// expressions of the affine map. Update our state to store the -/// dimensional/symbolic identifier. -ParseResult AffineParser::parseIdentifierDefinition(AffineExpr idExpr) { - if (getToken().isNot(Token::bare_identifier)) - return emitError("expected bare identifier"); - - auto name = getTokenSpelling(); - for (auto entry : dimsAndSymbols) { - if (entry.first == name) - return emitError("redefinition of identifier '" + name + "'"); - } - consumeToken(Token::bare_identifier); - - dimsAndSymbols.push_back({name, idExpr}); - return success(); -} - -/// Parse the list of dimensional identifiers to an affine map. -ParseResult AffineParser::parseDimIdList(unsigned &numDims) { - if (parseToken(Token::l_paren, - "expected '(' at start of dimensional identifiers list")) { - return failure(); - } - - auto parseElt = [&]() -> ParseResult { - auto dimension = getAffineDimExpr(numDims++, getContext()); - return parseIdentifierDefinition(dimension); - }; - return parseCommaSeparatedListUntil(Token::r_paren, parseElt); -} - -/// Parse the list of symbolic identifiers to an affine map. -ParseResult AffineParser::parseSymbolIdList(unsigned &numSymbols) { - consumeToken(Token::l_square); - auto parseElt = [&]() -> ParseResult { - auto symbol = getAffineSymbolExpr(numSymbols++, getContext()); - return parseIdentifierDefinition(symbol); - }; - return parseCommaSeparatedListUntil(Token::r_square, parseElt); -} - -/// Parse the list of symbolic identifiers to an affine map. 
-ParseResult -AffineParser::parseDimAndOptionalSymbolIdList(unsigned &numDims, - unsigned &numSymbols) { - if (parseDimIdList(numDims)) { - return failure(); - } - if (!getToken().is(Token::l_square)) { - numSymbols = 0; - return success(); - } - return parseSymbolIdList(numSymbols); -} - -/// Parses an ambiguous affine map or integer set definition inline. -ParseResult AffineParser::parseAffineMapOrIntegerSetInline(AffineMap &map, - IntegerSet &set) { - unsigned numDims = 0, numSymbols = 0; - - // List of dimensional and optional symbol identifiers. - if (parseDimAndOptionalSymbolIdList(numDims, numSymbols)) { - return failure(); - } - - // This is needed for parsing attributes as we wouldn't know whether we would - // be parsing an integer set attribute or an affine map attribute. - bool isArrow = getToken().is(Token::arrow); - bool isColon = getToken().is(Token::colon); - if (!isArrow && !isColon) { - return emitError("expected '->' or ':'"); - } else if (isArrow) { - parseToken(Token::arrow, "expected '->' or '['"); - map = parseAffineMapRange(numDims, numSymbols); - return map ? success() : failure(); - } else if (parseToken(Token::colon, "expected ':' or '['")) { +/// Parse a comma separated list of elements that must have at least one entry +/// in it. +ParseResult Parser::parseCommaSeparatedList( + const std::function &parseElement) { + // Non-empty case starts with an element. + if (parseElement()) return failure(); - } - if ((set = parseIntegerSetConstraints(numDims, numSymbols))) - return success(); - - return failure(); -} - -/// Parse an AffineMap where the dim and symbol identifiers are SSA ids. 
-ParseResult -AffineParser::parseAffineMapOfSSAIds(AffineMap &map, - OpAsmParser::Delimiter delimiter) { - Token::Kind rightToken; - switch (delimiter) { - case OpAsmParser::Delimiter::Square: - if (parseToken(Token::l_square, "expected '['")) - return failure(); - rightToken = Token::r_square; - break; - case OpAsmParser::Delimiter::Paren: - if (parseToken(Token::l_paren, "expected '('")) + // Otherwise we have a list of comma separated elements. + while (consumeIf(Token::comma)) { + if (parseElement()) return failure(); - rightToken = Token::r_paren; - break; - default: - return emitError("unexpected delimiter"); } - - SmallVector exprs; - auto parseElt = [&]() -> ParseResult { - auto elt = parseAffineExpr(); - exprs.push_back(elt); - return elt ? success() : failure(); - }; - - // Parse a multi-dimensional affine expression (a comma-separated list of - // 1-d affine expressions); the list can be empty. Grammar: - // multi-dim-affine-expr ::= `(` `)` - // | `(` affine-expr (`,` affine-expr)* `)` - if (parseCommaSeparatedListUntil(rightToken, parseElt, - /*allowEmptyList=*/true)) - return failure(); - // Parsed a valid affine map. - map = AffineMap::get(numDimOperands, dimsAndSymbols.size() - numDimOperands, - exprs, getContext()); return success(); } -/// Parse the range and sizes affine map definition inline. -/// -/// affine-map ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr -/// -/// multi-dim-affine-expr ::= `(` `)` -/// multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)` -AffineMap AffineParser::parseAffineMapRange(unsigned numDims, - unsigned numSymbols) { - parseToken(Token::l_paren, "expected '(' at start of affine map range"); - - SmallVector exprs; - auto parseElt = [&]() -> ParseResult { - auto elt = parseAffineExpr(); - ParseResult res = elt ? success() : failure(); - exprs.push_back(elt); - return res; - }; - - // Parse a multi-dimensional affine expression (a comma-separated list of - // 1-d affine expressions). 
Grammar: - // multi-dim-affine-expr ::= `(` `)` - // | `(` affine-expr (`,` affine-expr)* `)` - if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true)) - return AffineMap(); - - // Parsed a valid affine map. - return AffineMap::get(numDims, numSymbols, exprs, getContext()); -} - -/// Parse an affine constraint. -/// affine-constraint ::= affine-expr `>=` `0` -/// | affine-expr `==` `0` +/// Parse a comma-separated list of elements, terminated with an arbitrary +/// token. This allows empty lists if allowEmptyList is true. /// -/// isEq is set to true if the parsed constraint is an equality, false if it -/// is an inequality (greater than or equal). +/// abstract-list ::= rightToken // if allowEmptyList == true +/// abstract-list ::= element (',' element)* rightToken /// -AffineExpr AffineParser::parseAffineConstraint(bool *isEq) { - AffineExpr expr = parseAffineExpr(); - if (!expr) - return nullptr; - - if (consumeIf(Token::greater) && consumeIf(Token::equal) && - getToken().is(Token::integer)) { - auto dim = getToken().getUnsignedIntegerValue(); - if (dim.hasValue() && dim.getValue() == 0) { - consumeToken(Token::integer); - *isEq = false; - return expr; - } - return (emitError("expected '0' after '>='"), nullptr); +ParseResult Parser::parseCommaSeparatedListUntil( + Token::Kind rightToken, const std::function &parseElement, + bool allowEmptyList) { + // Handle the empty case. 
+ if (getToken().is(rightToken)) { + if (!allowEmptyList) + return emitError("expected list element"); + consumeToken(rightToken); + return success(); } - if (consumeIf(Token::equal) && consumeIf(Token::equal) && - getToken().is(Token::integer)) { - auto dim = getToken().getUnsignedIntegerValue(); - if (dim.hasValue() && dim.getValue() == 0) { - consumeToken(Token::integer); - *isEq = true; - return expr; - } - return (emitError("expected '0' after '=='"), nullptr); - } + if (parseCommaSeparatedList(parseElement) || + parseToken(rightToken, "expected ',' or '" + + Token::getTokenSpelling(rightToken) + "'")) + return failure(); - return (emitError("expected '== 0' or '>= 0' at end of affine constraint"), - nullptr); + return success(); } -/// Parse the constraints that are part of an integer set definition. -/// integer-set-inline -/// ::= dim-and-symbol-id-lists `:` -/// '(' affine-constraint-conjunction? ')' -/// affine-constraint-conjunction ::= affine-constraint (`,` -/// affine-constraint)* -/// -IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims, - unsigned numSymbols) { - if (parseToken(Token::l_paren, - "expected '(' at start of integer set constraint list")) - return IntegerSet(); - - SmallVector constraints; - SmallVector isEqs; - auto parseElt = [&]() -> ParseResult { - bool isEq; - auto elt = parseAffineConstraint(&isEq); - ParseResult res = elt ? success() : failure(); - if (elt) { - constraints.push_back(elt); - isEqs.push_back(isEq); - } - return res; - }; - - // Parse a list of affine constraints (comma-separated). - if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true)) - return IntegerSet(); - - // If no constraints were parsed, then treat this as a degenerate 'true' case. - if (constraints.empty()) { - /* 0 == 0 */ - auto zero = getAffineConstantExpr(0, getContext()); - return IntegerSet::get(numDims, numSymbols, zero, true); - } - - // Parsed a valid integer set. 
- return IntegerSet::get(numDims, numSymbols, constraints, isEqs); -} +InFlightDiagnostic Parser::emitError(SMLoc loc, const Twine &message) { + auto diag = mlir::emitError(getEncodedSourceLocation(loc), message); -/// Parse an ambiguous reference to either and affine map or an integer set. -ParseResult Parser::parseAffineMapOrIntegerSetReference(AffineMap &map, - IntegerSet &set) { - return AffineParser(state).parseAffineMapOrIntegerSetInline(map, set); -} -ParseResult Parser::parseAffineMapReference(AffineMap &map) { - llvm::SMLoc curLoc = getToken().getLoc(); - IntegerSet set; - if (parseAffineMapOrIntegerSetReference(map, set)) - return failure(); - if (set) - return emitError(curLoc, "expected AffineMap, but got IntegerSet"); - return success(); -} -ParseResult Parser::parseIntegerSetReference(IntegerSet &set) { - llvm::SMLoc curLoc = getToken().getLoc(); - AffineMap map; - if (parseAffineMapOrIntegerSetReference(map, set)) - return failure(); - if (map) - return emitError(curLoc, "expected IntegerSet, but got AffineMap"); - return success(); + // If we hit a parse error in response to a lexer error, then the lexer + // already reported the error. + if (getToken().is(Token::error)) + diag.abandon(); + return diag; } -/// Parse an AffineMap of SSA ids. The callback 'parseElement' is used to -/// parse SSA value uses encountered while parsing affine expressions. -ParseResult -Parser::parseAffineMapOfSSAIds(AffineMap &map, - function_ref parseElement, - OpAsmParser::Delimiter delimiter) { - return AffineParser(state, /*allowParsingSSAIds=*/true, parseElement) - .parseAffineMapOfSSAIds(map, delimiter); +/// Consume the specified token if present and return success. On failure, +/// output a diagnostic and return failure. 
+ParseResult Parser::parseToken(Token::Kind expectedToken, + const Twine &message) { + if (consumeIf(expectedToken)) + return success(); + return emitError(message); } //===----------------------------------------------------------------------===// @@ -3411,7 +187,7 @@ /// Parse an operation instance that is in the op-defined custom form. /// resultInfo specifies information about the "%name =" specifiers. - Operation *parseCustomOperation(ArrayRef resultInfo); + Operation *parseCustomOperation(ArrayRef resultIDs); //===--------------------------------------------------------------------===// // Region Parsing @@ -4298,7 +1074,7 @@ if (atToken.isNot(Token::at_identifier)) return failure(); - result = getBuilder().getStringAttr(extractSymbolReference(atToken)); + result = getBuilder().getStringAttr(atToken.getSymbolReference()); attrs.push_back(getBuilder().getNamedAttr(attrName, result)); parser.consumeToken(); return success(); @@ -5139,52 +1915,3 @@ sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); return parseSourceFile(sourceMgr, context); } - -/// Parses a symbol, of type 'T', and returns it if parsing was successful. If -/// parsing failed, nullptr is returned. The number of bytes read from the input -/// string is returned in 'numRead'. 
-template -static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t &numRead, - ParserFn &&parserFn) { - SymbolState aliasState; - return parseSymbol( - inputStr, context, aliasState, - [&](Parser &parser) { - SourceMgrDiagnosticHandler handler( - const_cast(parser.getSourceMgr()), - parser.getContext()); - return parserFn(parser); - }, - &numRead); -} - -Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context) { - size_t numRead = 0; - return parseAttribute(attrStr, context, numRead); -} -Attribute mlir::parseAttribute(StringRef attrStr, Type type) { - size_t numRead = 0; - return parseAttribute(attrStr, type, numRead); -} - -Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context, - size_t &numRead) { - return parseSymbol(attrStr, context, numRead, [](Parser &parser) { - return parser.parseAttribute(); - }); -} -Attribute mlir::parseAttribute(StringRef attrStr, Type type, size_t &numRead) { - return parseSymbol( - attrStr, type.getContext(), numRead, - [type](Parser &parser) { return parser.parseAttribute(type); }); -} - -Type mlir::parseType(StringRef typeStr, MLIRContext *context) { - size_t numRead = 0; - return parseType(typeStr, context, numRead); -} - -Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t &numRead) { - return parseSymbol(typeStr, context, numRead, - [](Parser &parser) { return parser.parseType(); }); -} diff --git a/mlir/lib/Parser/ParserState.h b/mlir/lib/Parser/ParserState.h new file mode 100644 --- /dev/null +++ b/mlir/lib/Parser/ParserState.h @@ -0,0 +1,85 @@ +//===- ParserState.h - MLIR ParserState -------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_LIB_PARSER_PARSERSTATE_H +#define MLIR_LIB_PARSER_PARSERSTATE_H + +#include "Lexer.h" +#include "mlir/IR/Attributes.h" +#include "llvm/ADT/StringMap.h" + +namespace mlir { +namespace detail { + +//===----------------------------------------------------------------------===// +// SymbolState +//===----------------------------------------------------------------------===// + +/// This class contains record of any parsed top-level symbols. +struct SymbolState { + // A map from attribute alias identifier to Attribute. + llvm::StringMap attributeAliasDefinitions; + + // A map from type alias identifier to Type. + llvm::StringMap typeAliasDefinitions; + + /// A set of locations into the main parser memory buffer for each of the + /// active nested parsers. Given that some nested parsers, i.e. custom dialect + /// parsers, operate on a temporary memory buffer, this provides an anchor + /// point for emitting diagnostics. + SmallVector nestedParserLocs; + + /// The top-level lexer that contains the original memory buffer provided by + /// the user. This is used by nested parsers to get a properly encoded source + /// location. + Lexer *topLevelLexer = nullptr; +}; + +//===----------------------------------------------------------------------===// +// ParserState +//===----------------------------------------------------------------------===// + +/// This class refers to all of the state maintained globally by the parser, +/// such as the current lexer position etc. +struct ParserState { + ParserState(const llvm::SourceMgr &sourceMgr, MLIRContext *ctx, + SymbolState &symbols) + : context(ctx), lex(sourceMgr, ctx), curToken(lex.lexToken()), + symbols(symbols), parserDepth(symbols.nestedParserLocs.size()) { + // Set the top level lexer for the symbol state if one doesn't exist. 
+ if (!symbols.topLevelLexer) + symbols.topLevelLexer = &lex; + } + ~ParserState() { + // Reset the top level lexer if it refers the lexer in our state. + if (symbols.topLevelLexer == &lex) + symbols.topLevelLexer = nullptr; + } + ParserState(const ParserState &) = delete; + void operator=(const ParserState &) = delete; + + /// The context we're parsing into. + MLIRContext *const context; + + /// The lexer for the source file we're parsing. + Lexer lex; + + /// This is the next token that hasn't been consumed yet. + Token curToken; + + /// The current state for symbol parsing. + SymbolState &symbols; + + /// The depth of this parser in the nested parsing stack. + size_t parserDepth; +}; + +} // end namespace detail +} // end namespace mlir + +#endif // MLIR_LIB_PARSER_PARSERSTATE_H diff --git a/mlir/lib/Parser/Token.h b/mlir/lib/Parser/Token.h --- a/mlir/lib/Parser/Token.h +++ b/mlir/lib/Parser/Token.h @@ -91,6 +91,10 @@ /// removing the quote characters and unescaping the contents of the string. std::string getStringValue() const; + /// Given a token containing a symbol reference, return the unescaped string + /// value. + std::string getSymbolReference() const; + // Location processing. llvm::SMLoc getLoc() const; llvm::SMLoc getEndLoc() const; diff --git a/mlir/lib/Parser/Token.cpp b/mlir/lib/Parser/Token.cpp --- a/mlir/lib/Parser/Token.cpp +++ b/mlir/lib/Parser/Token.cpp @@ -124,6 +124,18 @@ return result; } +/// Given a token containing a symbol reference, return the unescaped string +/// value. +std::string Token::getSymbolReference() const { + assert(is(Token::at_identifier) && "expected valid @-identifier"); + StringRef nameStr = getSpelling().drop_front(); + + // Check to see if the reference is a string literal, or a bare identifier. 
+ if (nameStr.front() == '"') + return getStringValue(); + return std::string(nameStr); +} + /// Given a hash_identifier token like #123, try to parse the number out of /// the identifier, returning None if it is a named identifier like #x or /// if the integer doesn't fit. diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Parser/TypeParser.cpp @@ -0,0 +1,570 @@ +//===- TypeParser.cpp - MLIR Type Parser Implementation -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the parser for the MLIR Types. +// +//===----------------------------------------------------------------------===// + +#include "Parser.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/StandardTypes.h" + +using namespace mlir; +using namespace mlir::detail; + +/// Optionally parse a type. +OptionalParseResult Parser::parseOptionalType(Type &type) { + // There are many different starting tokens for a type, check them here. + switch (getToken().getKind()) { + case Token::l_paren: + case Token::kw_memref: + case Token::kw_tensor: + case Token::kw_complex: + case Token::kw_tuple: + case Token::kw_vector: + case Token::inttype: + case Token::kw_bf16: + case Token::kw_f16: + case Token::kw_f32: + case Token::kw_f64: + case Token::kw_index: + case Token::kw_none: + case Token::exclamation_identifier: + return failure(!(type = parseType())); + + default: + return llvm::None; + } +} + +/// Parse an arbitrary type. 
+/// +/// type ::= function-type +/// | non-function-type +/// +Type Parser::parseType() { + if (getToken().is(Token::l_paren)) + return parseFunctionType(); + return parseNonFunctionType(); +} + +/// Parse a function result type. +/// +/// function-result-type ::= type-list-parens +/// | non-function-type +/// +ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl &elements) { + if (getToken().is(Token::l_paren)) + return parseTypeListParens(elements); + + Type t = parseNonFunctionType(); + if (!t) + return failure(); + elements.push_back(t); + return success(); +} + +/// Parse a list of types without an enclosing parenthesis. The list must have +/// at least one member. +/// +/// type-list-no-parens ::= type (`,` type)* +/// +ParseResult Parser::parseTypeListNoParens(SmallVectorImpl &elements) { + auto parseElt = [&]() -> ParseResult { + auto elt = parseType(); + elements.push_back(elt); + return elt ? success() : failure(); + }; + + return parseCommaSeparatedList(parseElt); +} + +/// Parse a parenthesized list of types. +/// +/// type-list-parens ::= `(` `)` +/// | `(` type-list-no-parens `)` +/// +ParseResult Parser::parseTypeListParens(SmallVectorImpl &elements) { + if (parseToken(Token::l_paren, "expected '('")) + return failure(); + + // Handle empty lists. + if (getToken().is(Token::r_paren)) + return consumeToken(), success(); + + if (parseTypeListNoParens(elements) || + parseToken(Token::r_paren, "expected ')'")) + return failure(); + return success(); +} + +/// Parse a complex type. +/// +/// complex-type ::= `complex` `<` type `>` +/// +Type Parser::parseComplexType() { + consumeToken(Token::kw_complex); + + // Parse the '<'. 
+ if (parseToken(Token::less, "expected '<' in complex type")) + return nullptr; + + llvm::SMLoc elementTypeLoc = getToken().getLoc(); + auto elementType = parseType(); + if (!elementType || + parseToken(Token::greater, "expected '>' in complex type")) + return nullptr; + if (!elementType.isa() && !elementType.isa()) + return emitError(elementTypeLoc, "invalid element type for complex"), + nullptr; + + return ComplexType::get(elementType); +} + +/// Parse a function type. +/// +/// function-type ::= type-list-parens `->` function-result-type +/// +Type Parser::parseFunctionType() { + assert(getToken().is(Token::l_paren)); + + SmallVector arguments, results; + if (parseTypeListParens(arguments) || + parseToken(Token::arrow, "expected '->' in function type") || + parseFunctionResultTypes(results)) + return nullptr; + + return builder.getFunctionType(arguments, results); +} + +/// Parse the offset and strides from a strided layout specification. +/// +/// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list +/// +ParseResult Parser::parseStridedLayout(int64_t &offset, + SmallVectorImpl &strides) { + // Parse offset. + consumeToken(Token::kw_offset); + if (!consumeIf(Token::colon)) + return emitError("expected colon after `offset` keyword"); + auto maybeOffset = getToken().getUnsignedIntegerValue(); + bool question = getToken().is(Token::question); + if (!maybeOffset && !question) + return emitError("invalid offset"); + offset = maybeOffset ? static_cast(maybeOffset.getValue()) + : MemRefType::getDynamicStrideOrOffset(); + consumeToken(); + + if (!consumeIf(Token::comma)) + return emitError("expected comma after offset value"); + + // Parse stride list. 
+ if (!consumeIf(Token::kw_strides)) + return emitError("expected `strides` keyword after offset specification"); + if (!consumeIf(Token::colon)) + return emitError("expected colon after `strides` keyword"); + if (failed(parseStrideList(strides))) + return emitError("invalid braces-enclosed stride list"); + if (llvm::any_of(strides, [](int64_t st) { return st == 0; })) + return emitError("invalid memref stride"); + + return success(); +} + +/// Parse a memref type. +/// +/// memref-type ::= ranked-memref-type | unranked-memref-type +/// +/// ranked-memref-type ::= `memref` `<` dimension-list-ranked type +/// (`,` semi-affine-map-composition)? (`,` +/// memory-space)? `>` +/// +/// unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>` +/// +/// semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map +/// memory-space ::= integer-literal /* | TODO: address-space-id */ +/// +Type Parser::parseMemRefType() { + consumeToken(Token::kw_memref); + + if (parseToken(Token::less, "expected '<' in memref type")) + return nullptr; + + bool isUnranked; + SmallVector dimensions; + + if (consumeIf(Token::star)) { + // This is an unranked memref type. + isUnranked = true; + if (parseXInDimensionList()) + return nullptr; + + } else { + isUnranked = false; + if (parseDimensionListRanked(dimensions)) + return nullptr; + } + + // Parse the element type. + auto typeLoc = getToken().getLoc(); + auto elementType = parseType(); + if (!elementType) + return nullptr; + + // Check that memref is formed from allowed types. + if (!elementType.isIntOrFloat() && !elementType.isa() && + !elementType.isa()) + return emitError(typeLoc, "invalid memref element type"), nullptr; + + // Parse semi-affine-map-composition. + SmallVector affineMapComposition; + Optional memorySpace; + unsigned numDims = dimensions.size(); + + auto parseElt = [&]() -> ParseResult { + // Check for the memory space. 
+ if (getToken().is(Token::integer)) { + if (memorySpace) + return emitError("multiple memory spaces specified in memref type"); + memorySpace = getToken().getUnsignedIntegerValue(); + if (!memorySpace.hasValue()) + return emitError("invalid memory space in memref type"); + consumeToken(Token::integer); + return success(); + } + if (isUnranked) + return emitError("cannot have affine map for unranked memref type"); + if (memorySpace) + return emitError("expected memory space to be last in memref type"); + + AffineMap map; + llvm::SMLoc mapLoc = getToken().getLoc(); + if (getToken().is(Token::kw_offset)) { + int64_t offset; + SmallVector strides; + if (failed(parseStridedLayout(offset, strides))) + return failure(); + // Construct strided affine map. + map = makeStridedLinearLayoutMap(strides, offset, state.context); + } else { + // Parse an affine map attribute. + auto affineMap = parseAttribute(); + if (!affineMap) + return failure(); + auto affineMapAttr = affineMap.dyn_cast(); + if (!affineMapAttr) + return emitError("expected affine map in memref type"); + map = affineMapAttr.getValue(); + } + + if (map.getNumDims() != numDims) { + size_t i = affineMapComposition.size(); + return emitError(mapLoc, "memref affine map dimension mismatch between ") + << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i)) + << " and affine map" << i + 1 << ": " << numDims + << " != " << map.getNumDims(); + } + numDims = map.getNumResults(); + affineMapComposition.push_back(map); + return success(); + }; + + // Parse a list of mappings and address space if present. + if (!consumeIf(Token::greater)) { + // Parse comma separated list of affine maps, followed by memory space. 
+ if (parseToken(Token::comma, "expected ',' or '>' in memref type") || + parseCommaSeparatedListUntil(Token::greater, parseElt, + /*allowEmptyList=*/false)) { + return nullptr; + } + } + + if (isUnranked) + return UnrankedMemRefType::get(elementType, memorySpace.getValueOr(0)); + + return MemRefType::get(dimensions, elementType, affineMapComposition, + memorySpace.getValueOr(0)); +} + +/// Parse any type except the function type. +/// +/// non-function-type ::= integer-type +/// | index-type +/// | float-type +/// | extended-type +/// | vector-type +/// | tensor-type +/// | memref-type +/// | complex-type +/// | tuple-type +/// | none-type +/// +/// index-type ::= `index` +/// float-type ::= `f16` | `bf16` | `f32` | `f64` +/// none-type ::= `none` +/// +Type Parser::parseNonFunctionType() { + switch (getToken().getKind()) { + default: + return (emitError("expected non-function type"), nullptr); + case Token::kw_memref: + return parseMemRefType(); + case Token::kw_tensor: + return parseTensorType(); + case Token::kw_complex: + return parseComplexType(); + case Token::kw_tuple: + return parseTupleType(); + case Token::kw_vector: + return parseVectorType(); + // integer-type + case Token::inttype: { + auto width = getToken().getIntTypeBitwidth(); + if (!width.hasValue()) + return (emitError("invalid integer width"), nullptr); + if (width.getValue() > IntegerType::kMaxWidth) { + emitError(getToken().getLoc(), "integer bitwidth is limited to ") + << IntegerType::kMaxWidth << " bits"; + return nullptr; + } + + IntegerType::SignednessSemantics signSemantics = IntegerType::Signless; + if (Optional signedness = getToken().getIntTypeSignedness()) + signSemantics = *signedness ? 
IntegerType::Signed : IntegerType::Unsigned; + + auto loc = getEncodedSourceLocation(getToken().getLoc()); + consumeToken(Token::inttype); + return IntegerType::getChecked(width.getValue(), signSemantics, loc); + } + + // float-type + case Token::kw_bf16: + consumeToken(Token::kw_bf16); + return builder.getBF16Type(); + case Token::kw_f16: + consumeToken(Token::kw_f16); + return builder.getF16Type(); + case Token::kw_f32: + consumeToken(Token::kw_f32); + return builder.getF32Type(); + case Token::kw_f64: + consumeToken(Token::kw_f64); + return builder.getF64Type(); + + // index-type + case Token::kw_index: + consumeToken(Token::kw_index); + return builder.getIndexType(); + + // none-type + case Token::kw_none: + consumeToken(Token::kw_none); + return builder.getNoneType(); + + // extended type + case Token::exclamation_identifier: + return parseExtendedType(); + } +} + +/// Parse a tensor type. +/// +/// tensor-type ::= `tensor` `<` dimension-list type `>` +/// dimension-list ::= dimension-list-ranked | `*x` +/// +Type Parser::parseTensorType() { + consumeToken(Token::kw_tensor); + + if (parseToken(Token::less, "expected '<' in tensor type")) + return nullptr; + + bool isUnranked; + SmallVector dimensions; + + if (consumeIf(Token::star)) { + // This is an unranked tensor type. + isUnranked = true; + + if (parseXInDimensionList()) + return nullptr; + + } else { + isUnranked = false; + if (parseDimensionListRanked(dimensions)) + return nullptr; + } + + // Parse the element type. + auto elementTypeLoc = getToken().getLoc(); + auto elementType = parseType(); + if (!elementType || parseToken(Token::greater, "expected '>' in tensor type")) + return nullptr; + if (!TensorType::isValidElementType(elementType)) + return emitError(elementTypeLoc, "invalid tensor element type"), nullptr; + + if (isUnranked) + return UnrankedTensorType::get(elementType); + return RankedTensorType::get(dimensions, elementType); +} + +/// Parse a tuple type. 
+/// +/// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>` +/// +Type Parser::parseTupleType() { + consumeToken(Token::kw_tuple); + + // Parse the '<'. + if (parseToken(Token::less, "expected '<' in tuple type")) + return nullptr; + + // Check for an empty tuple by directly parsing '>'. + if (consumeIf(Token::greater)) + return TupleType::get(getContext()); + + // Parse the element types and the '>'. + SmallVector types; + if (parseTypeListNoParens(types) || + parseToken(Token::greater, "expected '>' in tuple type")) + return nullptr; + + return TupleType::get(types, getContext()); +} + +/// Parse a vector type. +/// +/// vector-type ::= `vector` `<` non-empty-static-dimension-list type `>` +/// non-empty-static-dimension-list ::= decimal-literal `x` +/// static-dimension-list +/// static-dimension-list ::= (decimal-literal `x`)* +/// +VectorType Parser::parseVectorType() { + consumeToken(Token::kw_vector); + + if (parseToken(Token::less, "expected '<' in vector type")) + return nullptr; + + SmallVector dimensions; + if (parseDimensionListRanked(dimensions, /*allowDynamic=*/false)) + return nullptr; + if (dimensions.empty()) + return (emitError("expected dimension size in vector type"), nullptr); + if (any_of(dimensions, [](int64_t i) { return i <= 0; })) + return emitError(getToken().getLoc(), + "vector types must have positive constant sizes"), + nullptr; + + // Parse the element type. + auto typeLoc = getToken().getLoc(); + auto elementType = parseType(); + if (!elementType || parseToken(Token::greater, "expected '>' in vector type")) + return nullptr; + if (!VectorType::isValidElementType(elementType)) + return emitError(typeLoc, "vector elements must be int or float type"), + nullptr; + + return VectorType::get(dimensions, elementType); +} + +/// Parse a dimension list of a tensor or memref type. This populates the +/// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set and +/// errors out on `?` otherwise. 
+/// +/// dimension-list-ranked ::= (dimension `x`)* +/// dimension ::= `?` | decimal-literal +/// +/// When `allowDynamic` is not set, this is used to parse: +/// +/// static-dimension-list ::= (decimal-literal `x`)* +ParseResult +Parser::parseDimensionListRanked(SmallVectorImpl &dimensions, + bool allowDynamic) { + while (getToken().isAny(Token::integer, Token::question)) { + if (consumeIf(Token::question)) { + if (!allowDynamic) + return emitError("expected static shape"); + dimensions.push_back(-1); + } else { + // Hexadecimal integer literals (starting with `0x`) are not allowed in + // aggregate type declarations. Therefore, `0xf32` should be processed as + // a sequence of separate elements `0`, `x`, `f32`. + if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') { + // We can get here only if the token is an integer literal. Hexadecimal + // integer literals can only start with `0x` (`1x` wouldn't lex as a + // literal, just `1` would, at which point we don't get into this + // branch). + assert(getTokenSpelling()[0] == '0' && "invalid integer literal"); + dimensions.push_back(0); + state.lex.resetPointer(getTokenSpelling().data() + 1); + consumeToken(); + } else { + // Make sure this integer value is in bound and valid. + auto dimension = getToken().getUnsignedIntegerValue(); + if (!dimension.hasValue()) + return emitError("invalid dimension"); + dimensions.push_back((int64_t)dimension.getValue()); + consumeToken(Token::integer); + } + } + + // Make sure we have an 'x' or something like 'xbf32'. + if (parseXInDimensionList()) + return failure(); + } + + return success(); +} + +/// Parse an 'x' token in a dimension list, handling the case where the x is +/// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next +/// token. 
+ParseResult Parser::parseXInDimensionList() { + if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x') + return emitError("expected 'x' in dimension list"); + + // If we had a prefix of 'x', lex the next token immediately after the 'x'. + if (getTokenSpelling().size() != 1) + state.lex.resetPointer(getTokenSpelling().data() + 1); + + // Consume the 'x'. + consumeToken(Token::bare_identifier); + + return success(); +} + +// Parse a comma-separated list of dimensions, possibly empty: +// stride-list ::= `[` (dimension (`,` dimension)*)? `]` +ParseResult Parser::parseStrideList(SmallVectorImpl &dimensions) { + if (!consumeIf(Token::l_square)) + return failure(); + // Empty list early exit. + if (consumeIf(Token::r_square)) + return success(); + while (true) { + if (consumeIf(Token::question)) { + dimensions.push_back(MemRefType::getDynamicStrideOrOffset()); + } else { + // This must be an integer value. + int64_t val; + if (getToken().getSpelling().getAsInteger(10, val)) + return emitError("invalid integer value: ") << getToken().getSpelling(); + // Make sure it is not the one value for `?`. + if (ShapedType::isDynamic(val)) + return emitError("invalid integer value: ") + << getToken().getSpelling() + << ", use `?` to specify a dynamic dimension"; + dimensions.push_back(val); + consumeToken(Token::integer); + } + if (!consumeIf(Token::comma)) + break; + } + if (!consumeIf(Token::r_square)) + return failure(); + return success(); +}