diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -55,47 +55,42 @@ static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) { Builder &builder = parser.getBuilder(); - Attribute predicate; - SmallVector attrs; + StringAttr predicateAttr; OpAsmParser::OperandType lhs, rhs; Type type; llvm::SMLoc predicateLoc, trailingTypeLoc; if (parser.getCurrentLocation(&predicateLoc) || - parser.parseAttribute(predicate, "predicate", attrs) || + parser.parseAttribute(predicateAttr, "predicate", result.attributes) || parser.parseOperand(lhs) || parser.parseComma() || - parser.parseOperand(rhs) || parser.parseOptionalAttrDict(attrs) || - parser.parseColon() || parser.getCurrentLocation(&trailingTypeLoc) || - parser.parseType(type) || + parser.parseOperand(rhs) || + parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || + parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) || parser.resolveOperand(lhs, type, result.operands) || parser.resolveOperand(rhs, type, result.operands)) return failure(); // Replace the string attribute `predicate` with an integer attribute. - auto predicateStr = predicate.dyn_cast(); - if (!predicateStr) - return parser.emitError(predicateLoc, - "expected 'predicate' attribute of string type"); - int64_t predicateValue = 0; if (std::is_same()) { Optional predicate = - symbolizeICmpPredicate(predicateStr.getValue()); + symbolizeICmpPredicate(predicateAttr.getValue()); if (!predicate) return parser.emitError(predicateLoc) - << "'" << predicateStr.getValue() + << "'" << predicateAttr.getValue() << "' is an incorrect value of the 'predicate' attribute"; predicateValue = static_cast(predicate.getValue()); } else { Optional predicate = - symbolizeFCmpPredicate(predicateStr.getValue()); + symbolizeFCmpPredicate(predicateAttr.getValue()); if (!predicate) return parser.emitError(predicateLoc) - << "'" << predicateStr.getValue() + << "'" << predicateAttr.getValue() << "' is an incorrect value of the 'predicate' attribute"; predicateValue = static_cast(predicate.getValue()); } - attrs[0].second = parser.getBuilder().getI64IntegerAttr(predicateValue); + result.attributes[0].second = + parser.getBuilder().getI64IntegerAttr(predicateValue); // The result type is either i1 or a vector type if the inputs are // vectors. @@ -108,7 +103,6 @@ resultType = LLVMType::getVectorTy( resultType, argType.getUnderlyingType()->getVectorNumElements()); - result.attributes = attrs; result.addTypes({resultType}); return success(); } @@ -134,14 +128,13 @@ // ::= `llvm.alloca` ssa-use `x` type attribute-dict? // `:` type `,` type static ParseResult parseAllocaOp(OpAsmParser &parser, OperationState &result) { - SmallVector attrs; OpAsmParser::OperandType arraySize; Type type, elemType; llvm::SMLoc trailingTypeLoc; if (parser.parseOperand(arraySize) || parser.parseKeyword("x") || - parser.parseType(elemType) || parser.parseOptionalAttrDict(attrs) || - parser.parseColon() || parser.getCurrentLocation(&trailingTypeLoc) || - parser.parseType(type)) + parser.parseType(elemType) || + parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || + parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) return failure(); // Extract the result type from the trailing function type. @@ -155,7 +148,6 @@ if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands)) return failure(); - result.attributes = attrs; result.addTypes({funcType.getResult(0)}); return success(); } @@ -177,14 +169,13 @@ // ::= `llvm.getelementptr` ssa-use `[` ssa-use-list `]` // attribute-dict? `:` type static ParseResult parseGEPOp(OpAsmParser &parser, OperationState &result) { - SmallVector attrs; OpAsmParser::OperandType base; SmallVector indices; Type type; llvm::SMLoc trailingTypeLoc; if (parser.parseOperand(base) || parser.parseOperandList(indices, OpAsmParser::Delimiter::Square) || - parser.parseOptionalAttrDict(attrs) || parser.parseColon() || + parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) return failure(); @@ -202,7 +193,6 @@ parser.getNameLoc(), result.operands)) return failure(); - result.attributes = attrs; result.addTypes(funcType.getResults()); return success(); } @@ -233,20 +223,18 @@ // ::= `llvm.load` ssa-use attribute-dict? `:` type static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) { - SmallVector attrs; OpAsmParser::OperandType addr; Type type; llvm::SMLoc trailingTypeLoc; - if (parser.parseOperand(addr) || parser.parseOptionalAttrDict(attrs) || - parser.parseColon() || parser.getCurrentLocation(&trailingTypeLoc) || - parser.parseType(type) || + if (parser.parseOperand(addr) || + parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || + parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) || parser.resolveOperand(addr, type, result.operands)) return failure(); Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc); - result.attributes = attrs; result.addTypes(elemTy); return success(); } @@ -263,15 +251,14 @@ // ::= `llvm.store` ssa-use `,` ssa-use attribute-dict? `:` type static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) { - SmallVector attrs; OpAsmParser::OperandType addr, value; Type type; llvm::SMLoc trailingTypeLoc; if (parser.parseOperand(value) || parser.parseComma() || - parser.parseOperand(addr) || parser.parseOptionalAttrDict(attrs) || - parser.parseColon() || parser.getCurrentLocation(&trailingTypeLoc) || - parser.parseType(type)) + parser.parseOperand(addr) || + parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || + parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) return failure(); Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc); @@ -282,7 +269,6 @@ parser.resolveOperand(addr, type, result.operands)) return failure(); - result.attributes = attrs; return success(); } @@ -316,7 +302,6 @@ // ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)` // attribute-dict? `:` function-type static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) { - SmallVector attrs; SmallVector operands; Type type; SymbolRefAttr funcAttr; @@ -332,11 +317,11 @@ // Optionally parse a function identifier. if (isDirect) - if (parser.parseAttribute(funcAttr, "callee", attrs)) + if (parser.parseAttribute(funcAttr, "callee", result.attributes)) return failure(); if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || - parser.parseOptionalAttrDict(attrs) || parser.parseColon() || + parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) return failure(); @@ -396,7 +381,6 @@ result.addTypes(llvmResultType); } - result.attributes = attrs; return success(); } @@ -461,23 +445,18 @@ // resulting type wrapped in MLIR, or nullptr on error. static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser &parser, Type containerType, - Attribute positionAttr, + ArrayAttr positionAttr, llvm::SMLoc attributeLoc, llvm::SMLoc typeLoc) { auto wrappedContainerType = containerType.dyn_cast(); if (!wrappedContainerType) return parser.emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr; - auto positionArrayAttr = positionAttr.dyn_cast(); - if (!positionArrayAttr) - return parser.emitError(attributeLoc, "expected an array attribute"), - nullptr; - // Infer the element type from the structure type: iteratively step inside the // type by taking the element type, indexed by the position attribute for // structures. Check the position index before accessing, it is supposed to // be in bounds. - for (Attribute subAttr : positionArrayAttr) { + for (Attribute subAttr : positionAttr) { auto positionElementAttr = subAttr.dyn_cast(); if (!positionElementAttr) return parser.emitError(attributeLoc, @@ -512,16 +491,15 @@ // attribute-dict? `:` type static ParseResult parseExtractValueOp(OpAsmParser &parser, OperationState &result) { - SmallVector attrs; OpAsmParser::OperandType container; Type containerType; - Attribute positionAttr; + ArrayAttr positionAttr; llvm::SMLoc attributeLoc, trailingTypeLoc; if (parser.parseOperand(container) || parser.getCurrentLocation(&attributeLoc) || - parser.parseAttribute(positionAttr, "position", attrs) || - parser.parseOptionalAttrDict(attrs) || parser.parseColon() || + parser.parseAttribute(positionAttr, "position", result.attributes) || + parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(containerType) || parser.resolveOperand(container, containerType, result.operands)) @@ -532,7 +510,6 @@ if (!elementType) return failure(); - result.attributes = attrs; result.addTypes(elementType); return success(); } @@ -599,7 +576,7 @@ OperationState &result) { OpAsmParser::OperandType container, value; Type containerType; - Attribute positionAttr; + ArrayAttr positionAttr; llvm::SMLoc attributeLoc, trailingTypeLoc; if (parser.parseOperand(value) || parser.parseComma() || @@ -1080,15 +1057,15 @@ static ParseResult parseShuffleVectorOp(OpAsmParser &parser, OperationState &result) { llvm::SMLoc loc; - SmallVector attrs; OpAsmParser::OperandType v1, v2; - Attribute maskAttr; + ArrayAttr maskAttr; Type typeV1, typeV2; if (parser.getCurrentLocation(&loc) || parser.parseOperand(v1) || parser.parseComma() || parser.parseOperand(v2) || - parser.parseAttribute(maskAttr, "mask", attrs) || - parser.parseOptionalAttrDict(attrs) || parser.parseColonType(typeV1) || - parser.parseComma() || parser.parseType(typeV2) || + parser.parseAttribute(maskAttr, "mask", result.attributes) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(typeV1) || parser.parseComma() || + parser.parseType(typeV2) || parser.resolveOperand(v1, typeV1, result.operands) || parser.resolveOperand(v2, typeV2, result.operands)) return failure(); @@ -1097,10 +1074,8 @@ !wrappedContainerType1.getUnderlyingType()->isVectorTy()) return parser.emitError( loc, "expected LLVM IR dialect vector type for operand #1"); - auto vType = - LLVMType::getVectorTy(wrappedContainerType1.getVectorElementType(), - maskAttr.cast().size()); - result.attributes = attrs; + auto vType = LLVMType::getVectorTy( + wrappedContainerType1.getVectorElementType(), maskAttr.size()); result.addTypes(vType); return success(); } diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -12,7 +12,7 @@ // ----- func @icmp_non_string(%arg0 : !llvm.i32, %arg1 : !llvm<"i16">) { - // expected-error@+1 {{expected 'predicate' attribute of string type}} + // expected-error@+1 {{invalid kind of attribute specified}} llvm.icmp 42 %arg0, %arg0 : !llvm.i32 return } @@ -156,7 +156,7 @@ func @insertvalue_non_array_position() { // Note the double-type, otherwise attribute parsing consumes the trailing // type of the op as the (wrong) attribute type. - // expected-error@+1 {{expected an array attribute}} + // expected-error@+1 {{invalid kind of attribute specified}} llvm.insertvalue %a, %b 0 : i32 : !llvm<"{i32}"> } @@ -200,7 +200,7 @@ func @extractvalue_non_array_position() { // Note the double-type, otherwise attribute parsing consumes the trailing // type of the op as the (wrong) attribute type. - // expected-error@+1 {{expected an array attribute}} + // expected-error@+1 {{invalid kind of attribute specified}} llvm.extractvalue %b 0 : i32 : !llvm<"{i32}"> }