diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -357,6 +357,8 @@ let verifier = ?; let hasCanonicalizer = 1; + + let assemblyFormat = "$callee `(` $operands `)` attr-dict `:` type($callee)"; } def CeilFOp : FloatUnaryOp<"ceilf"> { @@ -490,6 +492,8 @@ let verifier = [{ return success(); }]; let hasFolder = 1; + + let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)"; } def CondBranchOp : Std_Op<"cond_br", [Terminator]> { @@ -761,6 +765,10 @@ }]; let hasFolder = 1; + + let assemblyFormat = [{ + $aggregate `[` $indices `]` attr-dict `:` type($aggregate) + }]; } def IndexCastOp : CastOp<"index_cast">, Arguments<(ins AnyType:$in)> { @@ -853,6 +861,8 @@ }]; let hasFolder = 1; + + let assemblyFormat = "$memref `[` $indices `]` attr-dict `:` type($memref)"; } def LogOp : FloatUnaryOp<"log"> { @@ -1090,6 +1100,10 @@ }]; let hasFolder = 1; + + let assemblyFormat = [{ + $condition `,` $true_value `,` $false_value attr-dict `:` type($result) + }]; } def SignExtendIOp : Std_Op<"sexti", @@ -1222,6 +1236,8 @@ [{ build(builder, result, aggregateType, element); }]>]; let hasFolder = 1; + + let assemblyFormat = "$input attr-dict `:` type($aggregate)"; } def StoreOp : Std_Op<"store", @@ -1264,6 +1280,10 @@ }]; let hasFolder = 1; + + let assemblyFormat = [{ + $value `,` $memref `[` $indices `]` attr-dict `:` type($memref) + }]; } def SubFOp : FloatArithmeticOp<"subf"> { @@ -1517,11 +1537,12 @@ result.addTypes(resultType); }]>]; - let extraClassDeclaration = [{ /// The result of a tensor_load is always a tensor. 
TensorType getType() { return getResult().getType().cast(); } }]; + + let assemblyFormat = "$memref attr-dict `:` type($memref)"; } def TensorStoreOp : Std_Op<"tensor_store", @@ -1545,6 +1566,8 @@ let arguments = (ins AnyTensor:$tensor, AnyMemRef:$memref); // TensorStoreOp is fully verified by traits. let verifier = ?; + + let assemblyFormat = "$tensor `,` $memref attr-dict `:` type($memref)"; } def TruncateIOp : Std_Op<"trunci", [NoSideEffect, SameOperandsAndResultShape]> { diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -363,6 +363,10 @@ return vector().getType().cast(); } }]; + + let assemblyFormat = [{ + $vector `[` $position `:` type($position) `]` attr-dict `:` type($vector) + }]; } def Vector_ExtractOp : @@ -512,6 +516,11 @@ return dest().getType().cast(); } }]; + + let assemblyFormat = [{ + $source `,` $dest `[` $position `:` type($position) `]` attr-dict `:` + type($result) + }]; } def Vector_InsertOp : diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -496,6 +496,12 @@ return failure(); return success(); } + template + ParseResult resolveOperands(Operands &&operands, Type type, llvm::SMLoc loc, + SmallVectorImpl &result) { + return resolveOperands(std::forward(operands), + ArrayRef(type), loc, result); + } template ParseResult resolveOperands(Operands &&operands, Types &&types, llvm::SMLoc loc, SmallVectorImpl &result) { diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -294,6 +294,11 @@ void addTypes(ArrayRef newTypes) { types.append(newTypes.begin(), newTypes.end()); } + template + 
std::enable_if_t>::value> + addTypes(RangeT &&newTypes) { + types.append(newTypes.begin(), newTypes.end()); + } /// Add an attribute with the specified name. void addAttribute(StringRef name, Attribute attr) { diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -505,29 +505,6 @@ }; } // end anonymous namespace. -static ParseResult parseCallIndirectOp(OpAsmParser &parser, - OperationState &result) { - FunctionType calleeType; - OpAsmParser::OperandType callee; - llvm::SMLoc operandsLoc; - SmallVector operands; - return failure( - parser.parseOperand(callee) || parser.getCurrentLocation(&operandsLoc) || - parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(calleeType) || - parser.resolveOperand(callee, calleeType, result.operands) || - parser.resolveOperands(operands, calleeType.getInputs(), operandsLoc, - result.operands) || - parser.addTypesToList(calleeType.getResults(), result.types)); -} - -static void print(OpAsmPrinter &p, CallIndirectOp op) { - p << "call_indirect " << op.getCallee() << '(' << op.getArgOperands() << ')'; - p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); - p << " : " << op.getCallee().getType(); -} - void CallIndirectOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); @@ -570,55 +547,6 @@ build->getI64IntegerAttr(static_cast(predicate))); } -static ParseResult parseCmpIOp(OpAsmParser &parser, OperationState &result) { - SmallVector ops; - SmallVector attrs; - Attribute predicateNameAttr; - Type type; - if (parser.parseAttribute(predicateNameAttr, CmpIOp::getPredicateAttrName(), - attrs) || - parser.parseComma() || parser.parseOperandList(ops, 2) || - parser.parseOptionalAttrDict(attrs) || parser.parseColonType(type) || - 
parser.resolveOperands(ops, type, result.operands)) - return failure(); - - if (!predicateNameAttr.isa()) - return parser.emitError(parser.getNameLoc(), - "expected string comparison predicate attribute"); - - // Rewrite string attribute to an enum value. - StringRef predicateName = predicateNameAttr.cast().getValue(); - Optional predicate = symbolizeCmpIPredicate(predicateName); - if (!predicate.hasValue()) - return parser.emitError(parser.getNameLoc()) - << "unknown comparison predicate \"" << predicateName << "\""; - - auto builder = parser.getBuilder(); - Type i1Type = getCheckedI1SameShape(type); - if (!i1Type) - return parser.emitError(parser.getNameLoc(), - "expected type with valid i1 shape"); - - attrs[0].second = builder.getI64IntegerAttr(static_cast(*predicate)); - result.attributes = attrs; - - result.addTypes({i1Type}); - return success(); -} - -static void print(OpAsmPrinter &p, CmpIOp op) { - p << "cmpi "; - - Builder b(op.getContext()); - auto predicateValue = - op.getAttrOfType(CmpIOp::getPredicateAttrName()).getInt(); - p << '"' << stringifyCmpIPredicate(static_cast(predicateValue)) - << '"' << ", " << op.lhs() << ", " << op.rhs(); - p.printOptionalAttrDict(op.getAttrs(), - /*elidedAttrs=*/{CmpIOp::getPredicateAttrName()}); - p << " : " << op.lhs().getType(); -} - // Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer // comparison predicates. 
static bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs, @@ -1486,30 +1414,6 @@ // ExtractElementOp //===----------------------------------------------------------------------===// -static void print(OpAsmPrinter &p, ExtractElementOp op) { - p << "extract_element " << op.getAggregate() << '[' << op.getIndices(); - p << ']'; - p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.getAggregate().getType(); -} - -static ParseResult parseExtractElementOp(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::OperandType aggregateInfo; - SmallVector indexInfo; - ShapedType type; - - auto indexTy = parser.getBuilder().getIndexType(); - return failure( - parser.parseOperand(aggregateInfo) || - parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(type) || - parser.resolveOperand(aggregateInfo, type, result.operands) || - parser.resolveOperands(indexInfo, indexTy, result.operands) || - parser.addTypeToList(type.getElementType(), result.types)); -} - static LogicalResult verify(ExtractElementOp op) { // Verify the # indices match if we have a ranked type. 
auto aggregateType = op.getAggregate().getType().cast(); @@ -1577,28 +1481,6 @@ // LoadOp //===----------------------------------------------------------------------===// -static void print(OpAsmPrinter &p, LoadOp op) { - p << "load " << op.getMemRef() << '[' << op.getIndices() << ']'; - p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.getMemRefType(); -} - -static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) { - OpAsmParser::OperandType memrefInfo; - SmallVector indexInfo; - MemRefType type; - - auto indexTy = parser.getBuilder().getIndexType(); - return failure( - parser.parseOperand(memrefInfo) || - parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(type) || - parser.resolveOperand(memrefInfo, type, result.operands) || - parser.resolveOperands(indexInfo, indexTy, result.operands) || - parser.addTypeToList(type.getElementType(), result.types)); -} - static LogicalResult verify(LoadOp op) { if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) return op.emitOpError("incorrect number of indices for load"); @@ -1902,31 +1784,6 @@ // SelectOp //===----------------------------------------------------------------------===// -static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) { - SmallVector ops; - SmallVector attrs; - Type type; - if (parser.parseOperandList(ops, 3) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(type)) - return failure(); - - auto i1Type = getCheckedI1SameShape(type); - if (!i1Type) - return parser.emitError(parser.getNameLoc(), - "expected type with valid i1 shape"); - - std::array types = {i1Type, type, type}; - return failure(parser.resolveOperands(ops, types, parser.getNameLoc(), - result.operands) || - parser.addTypeToList(type, result.types)); -} - -static void print(OpAsmPrinter &p, SelectOp op) { - p << "select " << op.getOperands() << " : " << 
op.getTrueValue().getType(); - p.printOptionalAttrDict(op.getAttrs()); -} - OpFoldResult SelectOp::fold(ArrayRef operands) { auto condition = getCondition(); @@ -1968,25 +1825,6 @@ // SplatOp //===----------------------------------------------------------------------===// -static void print(OpAsmPrinter &p, SplatOp op) { - p << "splat " << op.getOperand(); - p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.getType(); -} - -static ParseResult parseSplatOp(OpAsmParser &parser, OperationState &result) { - OpAsmParser::OperandType splatValueInfo; - ShapedType shapedType; - - return failure(parser.parseOperand(splatValueInfo) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(shapedType) || - parser.resolveOperand(splatValueInfo, - shapedType.getElementType(), - result.operands) || - parser.addTypeToList(shapedType, result.types)); -} - static LogicalResult verify(SplatOp op) { // TODO: we could replace this by a trait. if (op.getOperand().getType() != @@ -2017,32 +1855,6 @@ // StoreOp //===----------------------------------------------------------------------===// -static void print(OpAsmPrinter &p, StoreOp op) { - p << "store " << op.getValueToStore(); - p << ", " << op.getMemRef() << '[' << op.getIndices() << ']'; - p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.getMemRefType(); -} - -static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) { - OpAsmParser::OperandType storeValueInfo; - OpAsmParser::OperandType memrefInfo; - SmallVector indexInfo; - MemRefType memrefType; - - auto indexTy = parser.getBuilder().getIndexType(); - return failure( - parser.parseOperand(storeValueInfo) || parser.parseComma() || - parser.parseOperand(memrefInfo) || - parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(memrefType) || - parser.resolveOperand(storeValueInfo, memrefType.getElementType(), - result.operands) || - 
parser.resolveOperand(memrefInfo, memrefType, result.operands) || - parser.resolveOperands(indexInfo, indexTy, result.operands)); -} - static LogicalResult verify(StoreOp op) { if (op.getNumOperands() != 2 + op.getMemRefType().getRank()) return op.emitOpError("store index operand count not equal to memref rank"); @@ -2157,51 +1969,6 @@ } //===----------------------------------------------------------------------===// -// TensorLoadOp -//===----------------------------------------------------------------------===// - -static void print(OpAsmPrinter &p, TensorLoadOp op) { - p << "tensor_load " << op.getOperand(); - p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.getOperand().getType(); -} - -static ParseResult parseTensorLoadOp(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::OperandType op; - Type type; - return failure( - parser.parseOperand(op) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(type) || - parser.resolveOperand(op, type, result.operands) || - parser.addTypeToList(getTensorTypeFromMemRefType(type), result.types)); -} - -//===----------------------------------------------------------------------===// -// TensorStoreOp -//===----------------------------------------------------------------------===// - -static void print(OpAsmPrinter &p, TensorStoreOp op) { - p << "tensor_store " << op.tensor() << ", " << op.memref(); - p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.memref().getType(); -} - -static ParseResult parseTensorStoreOp(OpAsmParser &parser, - OperationState &result) { - SmallVector ops; - Type type; - llvm::SMLoc loc = parser.getCurrentLocation(); - return failure( - parser.parseOperandList(ops, /*requiredOperandCount=*/2) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(type) || - parser.resolveOperands(ops, {getTensorTypeFromMemRefType(type), type}, - loc, result.operands)); -} - 
-//===----------------------------------------------------------------------===// // TruncateIOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -412,31 +412,6 @@ // ExtractElementOp //===----------------------------------------------------------------------===// -static void print(OpAsmPrinter &p, vector::ExtractElementOp op) { - p << op.getOperationName() << " " << op.vector() << "[" << op.position() - << " : " << op.position().getType() << "]"; - p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.vector().getType(); -} - -static ParseResult parseExtractElementOp(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::OperandType vector, position; - Type positionType; - VectorType vectorType; - if (parser.parseOperand(vector) || parser.parseLSquare() || - parser.parseOperand(position) || parser.parseColonType(positionType) || - parser.parseRSquare() || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(vectorType)) - return failure(); - Type resultType = vectorType.getElementType(); - return failure( - parser.resolveOperand(vector, vectorType, result.operands) || - parser.resolveOperand(position, positionType, result.operands) || - parser.addTypeToList(resultType, result.types)); -} - static LogicalResult verify(vector::ExtractElementOp op) { VectorType vectorType = op.getVectorType(); if (vectorType.getRank() != 1) @@ -715,33 +690,6 @@ // InsertElementOp //===----------------------------------------------------------------------===// -static void print(OpAsmPrinter &p, InsertElementOp op) { - p << op.getOperationName() << " " << op.source() << ", " << op.dest() << "[" - << op.position() << " : " << op.position().getType() << "]"; - p.printOptionalAttrDict(op.getAttrs()); - p << " : " << 
op.dest().getType(); -} - -static ParseResult parseInsertElementOp(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::OperandType source, dest, position; - Type positionType; - VectorType destType; - if (parser.parseOperand(source) || parser.parseComma() || - parser.parseOperand(dest) || parser.parseLSquare() || - parser.parseOperand(position) || parser.parseColonType(positionType) || - parser.parseRSquare() || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(destType)) - return failure(); - Type sourceType = destType.getElementType(); - return failure( - parser.resolveOperand(source, sourceType, result.operands) || - parser.resolveOperand(dest, destType, result.operands) || - parser.resolveOperand(position, positionType, result.operands) || - parser.addTypeToList(destType, result.types)); -} - static LogicalResult verify(InsertElementOp op) { auto dstVectorType = op.getDestVectorType(); if (dstVectorType.getRank() != 1) diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -226,7 +226,7 @@ // Integer comparisons are not recognized for float types. 
func @func_with_ops(f32, f32) { ^bb0(%a : f32, %b : f32): - %r = cmpi "eq", %a, %b : f32 // expected-error {{operand #0 must be integer-like}} + %r = cmpi "eq", %a, %b : f32 // expected-error {{'lhs' must be integer-like, but got 'f32'}} } // ----- @@ -298,13 +298,13 @@ // ----- func @invalid_select_shape(%cond : i1, %idx : () -> ()) { - // expected-error@+1 {{expected type with valid i1 shape}} + // expected-error@+1 {{'result' must be integer-like or floating-point-like, but got '() -> ()'}} %sel = select %cond, %idx, %idx : () -> () // ----- func @invalid_cmp_shape(%idx : () -> ()) { - // expected-error@+1 {{expected type with valid i1 shape}} + // expected-error@+1 {{'lhs' must be integer-like, but got '() -> ()'}} %cmp = cmpi "eq", %idx, %idx : () -> () // ----- @@ -340,7 +340,7 @@ // ----- func @invalid_cmp_attr(%idx : i32) { - // expected-error@+1 {{expected string comparison predicate attribute}} + // expected-error@+1 {{invalid kind of attribute specified}} %cmp = cmpi i1, %idx, %idx : i32 // ----- diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -219,16 +219,26 @@ void setBuilderIdx(int idx) { builderIdx = idx; } /// Get the variable this type is resolved to, or None. - Optional getVariable() const { return variableName; } - void setVariable(StringRef variable) { variableName = variable; } + const NamedTypeConstraint *getVariable() const { return variable; } + Optional getVarTransformer() const { + return variableTransformer; + } + void setVariable(const NamedTypeConstraint *var, + Optional transformer) { + variable = var; + variableTransformer = transformer; + } private: /// If the type is resolved with a buildable type, this is the index into /// 'buildableTypes' in the parent format. 
Optional builderIdx; /// If the type is resolved based upon another operand or result, this is - /// the name of the variable that this type is resolved to. - Optional variableName; + /// the variable that this type is resolved to, or null if there is none. + const NamedTypeConstraint *variable = nullptr; + /// If the type is resolved based upon another operand or result, this is + /// a transformer to apply to the variable when resolving. + Optional variableTransformer; }; OperationFormat(const Operator &op) @@ -487,6 +497,34 @@ void OperationFormat::genParserTypeResolution(Operator &op, OpMethodBody &body) { + // If any type resolutions use transformed variables, verify that the types + // of those variables satisfy their constraints before transforming them. + SmallPtrSet verifiedVariables; + FmtContext verifierFCtx; + for (TypeResolution &resolver : + llvm::concat(resultTypes, operandTypes)) { + Optional transformer = resolver.getVarTransformer(); + if (!transformer) + continue; + // Ensure that we don't verify the same variables twice. + const NamedTypeConstraint *variable = resolver.getVariable(); + if (!verifiedVariables.insert(variable).second) + continue; + + auto constraint = variable->constraint; + body << " for (Type type : " << variable->name << "Types) {\n" + << " (void)type;\n" + << " if (!(" + << tgfmt(constraint.getConditionTemplate(), + &verifierFCtx.withSelf("type")) + << ")) {\n" + << formatv(" return parser.emitError(parser.getNameLoc()) << " + "\"'{0}' must be {1}, but got \" << type;\n", + variable->name, constraint.getDescription()) + << " }\n" + << " }\n"; + } + // Initialize the set of buildable types. if (!buildableTypes.empty()) { body << " Builder &builder = parser.getBuilder();\n"; @@ -498,18 +536,27 @@ << tgfmt(it.first, &typeBuilderCtx) << ";\n"; } + // Emit the code necessary for a type resolver.
+ auto emitTypeResolver = [&](TypeResolution &resolver, StringRef curVar) { + if (Optional val = resolver.getBuilderIdx()) { + body << "odsBuildableType" << *val; + } else if (const NamedTypeConstraint *var = resolver.getVariable()) { + if (Optional tform = resolver.getVarTransformer()) + body << tgfmt(*tform, &FmtContext().withSelf(var->name + "Types[0]")); + else + body << var->name << "Types"; + } else { + body << curVar << "Types"; + } + }; + // Resolve each of the result types. if (allResultTypes) { body << " result.addTypes(allResultTypes);\n"; } else { for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) { body << " result.addTypes("; - if (Optional val = resultTypes[i].getBuilderIdx()) - body << "odsBuildableType" << *val; - else if (Optional var = resultTypes[i].getVariable()) - body << *var << "Types"; - else - body << op.getResultName(i) << "Types"; + emitTypeResolver(resultTypes[i], op.getResultName(i)); body << ");\n"; } } @@ -552,25 +599,19 @@ if (hasAllOperands) { body << " if (parser.resolveOperands(allOperands, "; - auto emitOperandType = [&](int idx) { - if (Optional val = operandTypes[idx].getBuilderIdx()) - body << "ArrayRef(odsBuildableType" << *val << ")"; - else if (Optional var = operandTypes[idx].getVariable()) - body << *var << "Types"; - else - body << op.getOperand(idx).name << "Types"; - }; - // Group all of the operand types together to perform the resolution all at // once. Use llvm::concat to perform the merge. llvm::concat does not allow // the case of a single range, so guard it here. 
if (op.getNumOperands() > 1) { body << "llvm::concat("; - interleaveComma(llvm::seq(0, op.getNumOperands()), body, - emitOperandType); + interleaveComma(llvm::seq(0, op.getNumOperands()), body, [&](int i) { + body << "ArrayRef("; + emitTypeResolver(operandTypes[i], op.getOperand(i).name); + body << ")"; + }); body << ")"; } else { - emitOperandType(/*idx=*/0); + emitTypeResolver(operandTypes.front(), op.getOperand(0).name); } body << ", allOperandLoc, result.operands))\n" @@ -583,13 +624,12 @@ for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) { NamedTypeConstraint &operand = op.getOperand(i); body << " if (parser.resolveOperands(" << operand.name << "Operands, "; - if (Optional val = operandTypes[i].getBuilderIdx()) - body << "odsBuildableType" << *val << ", "; - else if (Optional var = operandTypes[i].getVariable()) - body << *var << "Types, " << operand.name << "OperandsLoc, "; - else - body << operand.name << "Types, " << operand.name << "OperandsLoc, "; - body << "result.operands))\n return failure();\n"; + emitTypeResolver(operandTypes[i], operand.name); + + // If this isn't a buildable type, verify the sizes match by adding the loc. + if (!operandTypes[i].getBuilderIdx()) + body << ", " << operand.name << "OperandsLoc"; + body << ", result.operands))\n return failure();\n"; } } @@ -954,18 +994,30 @@ LogicalResult parse(); private: + /// This struct represents a type resolution instance. It includes a specific + /// type as well as an optional transformer to apply to that type in order to + /// properly resolve the type of a variable. + struct TypeResolutionInstance { + const NamedTypeConstraint *type; + Optional transformer; + }; + /// Given the values of an `AllTypesMatch` trait, check for inferrable type /// resolution. 
void handleAllTypesMatchConstraint( ArrayRef values, - llvm::StringMap &variableTyResolver); + llvm::StringMap &variableTyResolver); /// Check for inferrable type resolution given all operands, and or results, /// have the same type. If 'includeResults' is true, the results also have the /// same type as all of the operands. void handleSameTypesConstraint( - llvm::StringMap &variableTyResolver, + llvm::StringMap &variableTyResolver, bool includeResults); + /// Returns an argument with the given name that has been seen within the + /// format. + const NamedTypeConstraint *findSeenArg(StringRef name); + /// Parse a specific element. LogicalResult parseElement(std::unique_ptr &element, bool isTopLevel); @@ -1044,16 +1096,21 @@ return emitError(loc, "format missing 'attr-dict' directive"); // Check for any type traits that we can use for inferring types. - llvm::StringMap variableTyResolver; + llvm::StringMap variableTyResolver; for (const OpTrait &trait : op.getTraits()) { const llvm::Record &def = trait.getDef(); - if (def.isSubClassOf("AllTypesMatch")) + if (def.isSubClassOf("AllTypesMatch")) { handleAllTypesMatchConstraint(def.getValueAsListOfStrings("values"), variableTyResolver); - else if (def.getName() == "SameTypeOperands") + } else if (def.getName() == "SameTypeOperands") { handleSameTypesConstraint(variableTyResolver, /*includeResults=*/false); - else if (def.getName() == "SameOperandsAndResultType") + } else if (def.getName() == "SameOperandsAndResultType") { handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true); + } else if (def.isSubClassOf("TypesMatchWith")) { + if (const auto *lhsArg = findSeenArg(def.getValueAsString("lhs"))) + variableTyResolver[def.getValueAsString("rhs")] = { + lhsArg, def.getValueAsString("transformer")}; + } } // Check that all of the result types can be inferred. @@ -1066,7 +1123,8 @@ // Check to see if we can infer this type from another variable. 
auto varResolverIt = variableTyResolver.find(op.getResultName(i)); if (varResolverIt != variableTyResolver.end()) { - fmt.resultTypes[i].setVariable(varResolverIt->second->name); + fmt.resultTypes[i].setVariable(varResolverIt->second.type, + varResolverIt->second.transformer); continue; } @@ -1102,7 +1160,8 @@ // Check to see if we can infer this type from another variable. auto varResolverIt = variableTyResolver.find(op.getOperand(i).name); if (varResolverIt != variableTyResolver.end()) { - fmt.operandTypes[i].setVariable(varResolverIt->second->name); + fmt.operandTypes[i].setVariable(varResolverIt->second.type, + varResolverIt->second.transformer); continue; } @@ -1121,30 +1180,23 @@ void FormatParser::handleAllTypesMatchConstraint( ArrayRef values, - llvm::StringMap &variableTyResolver) { + llvm::StringMap &variableTyResolver) { for (unsigned i = 0, e = values.size(); i != e; ++i) { // Check to see if this value matches a resolved operand or result type. - const NamedTypeConstraint *arg = nullptr; - if ((arg = findArg(op.getOperands(), values[i]))) { - if (!seenOperandTypes.test(arg - op.operand_begin())) - continue; - } else if ((arg = findArg(op.getResults(), values[i]))) { - if (!seenResultTypes.test(arg - op.result_begin())) - continue; - } else { + const NamedTypeConstraint *arg = findSeenArg(values[i]); + if (!arg) continue; - } // Mark this value as the type resolver for the other variables. for (unsigned j = 0; j != i; ++j) - variableTyResolver[values[j]] = arg; + variableTyResolver[values[j]] = {arg, llvm::None}; for (unsigned j = i + 1; j != e; ++j) - variableTyResolver[values[j]] = arg; + variableTyResolver[values[j]] = {arg, llvm::None}; } } void FormatParser::handleSameTypesConstraint( - llvm::StringMap &variableTyResolver, + llvm::StringMap &variableTyResolver, bool includeResults) { const NamedTypeConstraint *resolver = nullptr; int resolvedIt = -1; @@ -1160,14 +1212,22 @@ // Set the resolvers for each operand and result. 
for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) if (!seenOperandTypes.test(i) && !op.getOperand(i).name.empty()) - variableTyResolver[op.getOperand(i).name] = resolver; + variableTyResolver[op.getOperand(i).name] = {resolver, llvm::None}; if (includeResults) { for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) if (!seenResultTypes.test(i) && !op.getResultName(i).empty()) - variableTyResolver[op.getResultName(i)] = resolver; + variableTyResolver[op.getResultName(i)] = {resolver, llvm::None}; } } +const NamedTypeConstraint *FormatParser::findSeenArg(StringRef name) { + if (auto *arg = findArg(op.getOperands(), name)) + return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr; + if (auto *arg = findArg(op.getResults(), name)) + return seenResultTypes.test(arg - op.result_begin()) ? arg : nullptr; + return nullptr; +} + LogicalResult FormatParser::parseElement(std::unique_ptr &element, bool isTopLevel) { // Directives. @@ -1191,7 +1251,8 @@ StringRef name = varTok.getSpelling().drop_front(); llvm::SMLoc loc = varTok.getLoc(); - // Check that the parsed argument is something actually registered on the op. + // Check that the parsed argument is something actually registered on the + // op. /// Attributes if (const NamedAttribute *attr = findArg(op.getAttributes(), name)) { if (isTopLevel && !seenAttrs.insert(attr).second)