[mlir] Parser: validate builtin types while parsing instead of via getChecked

Complex, memref, integer, tensor and vector types are now built with `get`
after the parser has validated them itself, so diagnostics are reported at a
parser source location (element type, bitwidth token, affine map) rather than
through getChecked. Memref layouts are additionally checked for dimension
agreement between the memref rank and the first map, and between each pair of
consecutive maps in the composition.

diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -969,13 +969,16 @@
   if (parseToken(Token::less, "expected '<' in complex type"))
     return nullptr;
 
-  auto typeLocation = getEncodedSourceLocation(getToken().getLoc());
+  llvm::SMLoc elementTypeLoc = getToken().getLoc();
   auto elementType = parseType();
   if (!elementType ||
       parseToken(Token::greater, "expected '>' in complex type"))
     return nullptr;
+  if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>())
+    return emitError(elementTypeLoc, "invalid element type for complex"),
+           nullptr;
 
-  return ComplexType::getChecked(elementType, typeLocation);
+  return ComplexType::get(elementType);
 }
 
 /// Parse an extended type.
@@ -1096,69 +1099,79 @@
   if (!elementType)
     return nullptr;
 
+  // Check that memref is formed from allowed types.
+  if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
+      !elementType.isa<ComplexType>())
+    return emitError(typeLoc, "invalid memref element type"), nullptr;
+
   // Parse semi-affine-map-composition.
   SmallVector<AffineMap, 2> affineMapComposition;
-  unsigned memorySpace = 0;
-  bool parsedMemorySpace = false;
+  Optional<unsigned> memorySpace;
+  unsigned numDims = dimensions.size();
 
   auto parseElt = [&]() -> ParseResult {
+    // Check for the memory space.
     if (getToken().is(Token::integer)) {
-      // Parse memory space.
-      if (parsedMemorySpace)
+      if (memorySpace)
         return emitError("multiple memory spaces specified in memref type");
-      auto v = getToken().getUnsignedIntegerValue();
-      if (!v.hasValue())
+      memorySpace = getToken().getUnsignedIntegerValue();
+      if (!memorySpace.hasValue())
         return emitError("invalid memory space in memref type");
-      memorySpace = v.getValue();
       consumeToken(Token::integer);
-      parsedMemorySpace = true;
+      return success();
+    }
+    if (isUnranked)
+      return emitError("cannot have affine map for unranked memref type");
+    if (memorySpace)
+      return emitError("expected memory space to be last in memref type");
+
+    AffineMap map;
+    llvm::SMLoc mapLoc = getToken().getLoc();
+    if (getToken().is(Token::kw_offset)) {
+      int64_t offset;
+      SmallVector<int64_t, 4> strides;
+      if (failed(parseStridedLayout(offset, strides)))
+        return failure();
+      // Construct strided affine map.
+      map = makeStridedLinearLayoutMap(strides, offset, state.context);
     } else {
-      if (isUnranked)
-        return emitError("cannot have affine map for unranked memref type");
-      if (parsedMemorySpace)
-        return emitError("expected memory space to be last in memref type");
-      if (getToken().is(Token::kw_offset)) {
-        int64_t offset;
-        SmallVector<int64_t, 4> strides;
-        if (failed(parseStridedLayout(offset, strides)))
-          return failure();
-        // Construct strided affine map.
-        auto map = makeStridedLinearLayoutMap(strides, offset,
-                                              elementType.getContext());
-        affineMapComposition.push_back(map);
-      } else {
-        // Parse affine map.
-        auto affineMap = parseAttribute();
-        if (!affineMap)
-          return failure();
-        // Verify that the parsed attribute is an affine map.
-        if (auto affineMapAttr = affineMap.dyn_cast<AffineMapAttr>())
-          affineMapComposition.push_back(affineMapAttr.getValue());
-        else
-          return emitError("expected affine map in memref type");
-      }
+      // Parse an affine map attribute.
+      auto affineMap = parseAttribute();
+      if (!affineMap)
+        return failure();
+      auto affineMapAttr = affineMap.dyn_cast<AffineMapAttr>();
+      if (!affineMapAttr)
+        return emitError("expected affine map in memref type");
+      map = affineMapAttr.getValue();
+    }
+
+    if (map.getNumDims() != numDims) {
+      size_t i = affineMapComposition.size();
+      return emitError(mapLoc, "memref affine map dimension mismatch between ")
+             << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i))
+             << " and affine map " << i + 1 << ": " << numDims
+             << " != " << map.getNumDims();
     }
+    numDims = map.getNumResults();
+    affineMapComposition.push_back(map);
     return success();
   };
 
   // Parse a list of mappings and address space if present.
-  if (consumeIf(Token::comma)) {
+  if (!consumeIf(Token::greater)) {
     // Parse comma separated list of affine maps, followed by memory space.
-    if (parseCommaSeparatedListUntil(Token::greater, parseElt,
+    if (parseToken(Token::comma, "expected ',' or '>' in memref type") ||
+        parseCommaSeparatedListUntil(Token::greater, parseElt,
                                      /*allowEmptyList=*/false)) {
       return nullptr;
     }
-  } else {
-    if (parseToken(Token::greater, "expected ',' or '>' in memref type"))
-      return nullptr;
   }
 
   if (isUnranked)
-    return UnrankedMemRefType::getChecked(elementType, memorySpace,
-                                          getEncodedSourceLocation(typeLoc));
+    return UnrankedMemRefType::get(elementType, memorySpace.getValueOr(0));
 
-  return MemRefType::getChecked(dimensions, elementType, affineMapComposition,
-                                memorySpace, getEncodedSourceLocation(typeLoc));
+  return MemRefType::get(dimensions, elementType, affineMapComposition,
+                         memorySpace.getValueOr(0));
 }
 
 /// Parse any type except the function type.
@@ -1197,9 +1210,14 @@
     auto width = getToken().getIntTypeBitwidth();
     if (!width.hasValue())
       return (emitError("invalid integer width"), nullptr);
-    auto loc = getEncodedSourceLocation(getToken().getLoc());
+    if (width.getValue() > IntegerType::kMaxWidth) {
+      emitError(getToken().getLoc(), "integer bitwidth is limited to ")
+          << IntegerType::kMaxWidth << " bits";
+      return nullptr;
+    }
+
     consumeToken(Token::inttype);
-    return IntegerType::getChecked(width.getValue(), builder.getContext(), loc);
+    return IntegerType::get(width.getValue(), builder.getContext());
   }
 
   // float-type
@@ -1260,14 +1278,16 @@
   }
 
   // Parse the element type.
-  auto typeLocation = getEncodedSourceLocation(getToken().getLoc());
+  auto elementTypeLoc = getToken().getLoc();
   auto elementType = parseType();
   if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
     return nullptr;
+  if (!TensorType::isValidElementType(elementType))
+    return emitError(elementTypeLoc, "invalid tensor element type"), nullptr;
 
   if (isUnranked)
-    return UnrankedTensorType::getChecked(elementType, typeLocation);
-  return RankedTensorType::getChecked(dimensions, elementType, typeLocation);
+    return UnrankedTensorType::get(elementType);
+  return RankedTensorType::get(dimensions, elementType);
 }
 
 /// Parse a tuple type.
@@ -1312,15 +1332,21 @@
     return nullptr;
   if (dimensions.empty())
     return (emitError("expected dimension size in vector type"), nullptr);
+  if (any_of(dimensions, [](int64_t i) { return i <= 0; }))
+    return emitError(getToken().getLoc(),
+                     "vector types must have positive constant sizes"),
+           nullptr;
 
   // Parse the element type.
   auto typeLoc = getToken().getLoc();
   auto elementType = parseType();
   if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
     return nullptr;
+  if (!VectorType::isValidElementType(elementType))
+    return emitError(typeLoc, "vector elements must be int or float type"),
+           nullptr;
 
-  return VectorType::getChecked(dimensions, elementType,
-                                getEncodedSourceLocation(typeLoc));
+  return VectorType::get(dimensions, elementType);
 }
 
 /// Parse a dimension list of a tensor or memref type.  This populates the