diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -400,7 +400,8 @@ // Misc operations. def LLVM_SelectOp - : LLVM_OneResultOp<"select", [NoSideEffect]>, + : LLVM_OneResultOp<"select", + [NoSideEffect, AllTypesMatch<["trueValue", "falseValue", "res"]>]>, Arguments<(ins LLVM_Type:$condition, LLVM_Type:$trueValue, LLVM_Type:$falseValue)>, LLVM_Builder< @@ -410,8 +411,7 @@ "Value rhs", [{ build(b, result, lhs.getType(), condition, lhs, rhs); }]>]; - let parser = [{ return parseSelectOp(parser, result); }]; - let printer = [{ printSelectOp(p, *this); }]; + let assemblyFormat = "operands attr-dict `:` type($condition) `,` type($res)"; } // Terminators. diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBitOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBitOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBitOps.td @@ -99,7 +99,8 @@ // ----- -def SPV_BitFieldInsertOp : SPV_Op<"BitFieldInsert", [NoSideEffect]> { +def SPV_BitFieldInsertOp : SPV_Op<"BitFieldInsert", + [NoSideEffect, AllTypesMatch<["base", "insert", "result"]>]> { let summary = [{ Make a copy of an object, with a modified bit field that comes from another object. @@ -163,6 +164,12 @@ let results = (outs SPV_ScalarOrVectorOf:$result ); + + let verifier = [{ return success(); }]; + + let assemblyFormat = [{ + operands attr-dict `:` type($base) `,` type($offset) `,` type($count) + }]; } // ----- diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td @@ -794,7 +794,8 @@ // ----- -def SPV_SelectOp : SPV_Op<"Select", [NoSideEffect]> { +def SPV_SelectOp : SPV_Op<"Select", + [NoSideEffect, AllTypesMatch<["true_value", "false_value", "result"]>]> { let summary = [{ Select between two objects. Before version 1.4, results are only computed per component. @@ -851,6 +852,10 @@ let builders = [OpBuilder<[{Builder *builder, OperationState &state, Value cond, Value trueValue, Value falseValue}]>]; + + let assemblyFormat = [{ + operands attr-dict `:` type($condition) `,` type($result) + }]; } // ----- diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -407,10 +407,9 @@ Vector_Op<"insert", [NoSideEffect, PredOpTrait<"source operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>, - PredOpTrait<"dest operand and result have same type", - TCresIsSameAsOpBase<0, 1>>]>, + AllTypesMatch<["dest", "res"]>]>, Arguments<(ins AnyType:$source, AnyVector:$dest, I64ArrayAttr:$position)>, - Results<(outs AnyVector)> { + Results<(outs AnyVector:$res)> { let summary = "insert operation"; let description = [{ Takes an n-D source vector, an (n+k)-D destination vector and a k-D position @@ -425,6 +424,10 @@ f32 into vector<4x8x16xf32> ``` }]; + let assemblyFormat = [{ + $source `,` $dest $position attr-dict `:` type($source) `into` type($dest) + }]; + let builders = [OpBuilder< "Builder *builder, OperationState &result, Value source, " # "Value dest, ArrayRef">]; @@ -497,11 +500,10 @@ Vector_Op<"insert_strided_slice", [NoSideEffect, PredOpTrait<"operand #0 and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>, - PredOpTrait<"dest operand and result have same type", - TCresIsSameAsOpBase<0, 1>>]>, + AllTypesMatch<["dest", "res"]>]>, Arguments<(ins AnyVector:$source, AnyVector:$dest, I64ArrayAttr:$offsets, I64ArrayAttr:$strides)>, - Results<(outs AnyVector)> { + Results<(outs AnyVector:$res)> { let summary = "strided_slice operation"; let description = [{ Takes a k-D source vector, an n-D destination vector (n >= k), n-sized @@ -522,6 +524,11 @@ vector<2x4xf32> into vector<16x4x8xf32> ``` }]; + + let assemblyFormat = [{ + $source `,` $dest attr-dict `:` type($source) `into` type($dest) + }]; + let builders = [OpBuilder< "Builder *builder, OperationState &result, Value source, Value dest, " # "ArrayRef offsets, ArrayRef strides">]; diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1650,7 +1650,9 @@ string description> : PredOpTrait< "all of {" # StrJoin.result # "} have same " # description, - AllMatchSameOperatorPred>; + AllMatchSameOperatorPred> { + list values = names; +} class AllElementCountsMatch names> : AllMatchSameOperatorTrait.result, diff --git a/mlir/include/mlir/TableGen/OpTrait.h b/mlir/include/mlir/TableGen/OpTrait.h --- a/mlir/include/mlir/TableGen/OpTrait.h +++ b/mlir/include/mlir/TableGen/OpTrait.h @@ -49,6 +49,9 @@ Kind getKind() const { return kind; } + // Returns the Tablegen definition this operator was constructed from. + const llvm::Record &getDef() const { return *def; } + protected: // The TableGen definition of this trait. const llvm::Record *def; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -556,40 +556,6 @@ return success(); } -//===----------------------------------------------------------------------===// -// Printing/parsing for LLVM::SelectOp. -//===----------------------------------------------------------------------===// - -static void printSelectOp(OpAsmPrinter &p, SelectOp &op) { - p << op.getOperationName() << ' ' << op.condition() << ", " << op.trueValue() - << ", " << op.falseValue(); - p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.condition().getType() << ", " << op.trueValue().getType(); -} - -// ::= `llvm.select` ssa-use `,` ssa-use `,` ssa-use -// attribute-dict? `:` type, type -static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) { - OpAsmParser::OperandType condition, trueValue, falseValue; - Type conditionType, argType; - - if (parser.parseOperand(condition) || parser.parseComma() || - parser.parseOperand(trueValue) || parser.parseComma() || - parser.parseOperand(falseValue) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(conditionType) || parser.parseComma() || - parser.parseType(argType)) - return failure(); - - if (parser.resolveOperand(condition, conditionType, result.operands) || - parser.resolveOperand(trueValue, argType, result.operands) || - parser.resolveOperand(falseValue, argType, result.operands)) - return failure(); - - result.addTypes(argType); - return success(); -} - //===----------------------------------------------------------------------===// // Printing/parsing for LLVM::BrOp. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -1067,54 +1067,6 @@ results.insert(context); } -//===----------------------------------------------------------------------===// -// spv.BitFieldInsert -//===----------------------------------------------------------------------===// - -static ParseResult parseBitFieldInsertOp(OpAsmParser &parser, - OperationState &state) { - SmallVector operandInfo; - Type baseType; - Type offsetType; - Type countType; - auto loc = parser.getCurrentLocation(); - - if (parser.parseOperandList(operandInfo, 4) || parser.parseColon() || - parser.parseType(baseType) || parser.parseComma() || - parser.parseType(offsetType) || parser.parseComma() || - parser.parseType(countType) || - parser.resolveOperands(operandInfo, - {baseType, baseType, offsetType, countType}, loc, - state.operands)) { - return failure(); - } - state.addTypes(baseType); - return success(); -} - -static void print(spirv::BitFieldInsertOp bitFieldInsertOp, - OpAsmPrinter &printer) { - printer << spirv::BitFieldInsertOp::getOperationName() << ' ' - << bitFieldInsertOp.getOperands() << " : " - << bitFieldInsertOp.base().getType() << ", " - << bitFieldInsertOp.offset().getType() << ", " - << bitFieldInsertOp.count().getType(); -} - -static LogicalResult verify(spirv::BitFieldInsertOp bitFieldOp) { - auto baseType = bitFieldOp.base().getType(); - auto insertType = bitFieldOp.insert().getType(); - auto resultType = bitFieldOp.getResult().getType(); - - if ((baseType != insertType) || (baseType != resultType)) { - return bitFieldOp.emitError("expected the same type for the base operand, " - "insert operand, and " - "result, but provided ") - << baseType << ", " << insertType << " and " << resultType; - } - return success(); -} - //===----------------------------------------------------------------------===// // spv.BranchOp //===----------------------------------------------------------------------===// @@ -2524,42 +2476,9 @@ build(builder, state, trueValue.getType(), cond, trueValue, falseValue); } -static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &state) { - OpAsmParser::OperandType condition; - SmallVector operands; - SmallVector types; - auto loc = parser.getCurrentLocation(); - if (parser.parseOperand(condition) || parser.parseComma() || - parser.parseOperandList(operands, 2) || - parser.parseColonTypeList(types)) { - return failure(); - } - if (types.size() != 2) { - return parser.emitError( - loc, "need exactly two trailing types for select condition and object"); - } - if (parser.resolveOperand(condition, types[0], state.operands) || - parser.resolveOperands(operands, types[1], state.operands)) { - return failure(); - } - return parser.addTypesToList(types[1], state.types); -} - -static void print(spirv::SelectOp op, OpAsmPrinter &printer) { - printer << spirv::SelectOp::getOperationName() << " " << op.getOperands() - << " : " << op.condition().getType() << ", " << op.result().getType(); -} - static LogicalResult verify(spirv::SelectOp op) { - auto resultTy = op.result().getType(); - if (op.true_value().getType() != resultTy) { - return op.emitOpError("result type and true value type must be the same"); - } - if (op.false_value().getType() != resultTy) { - return op.emitOpError("result type and false value type must be the same"); - } if (auto conditionTy = op.condition().getType().dyn_cast()) { - auto resultVectorTy = resultTy.dyn_cast(); + auto resultVectorTy = op.result().getType().dyn_cast(); if (!resultVectorTy) { return op.emitOpError("result expected to be of vector type when " "condition is of vector type"); diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -700,31 +700,6 @@ result.addAttribute(getPositionAttrName(), positionAttr); } -static void print(OpAsmPrinter &p, InsertOp op) { - p << op.getOperationName() << " " << op.source() << ", " << op.dest() - << op.position(); - p.printOptionalAttrDict(op.getAttrs(), {InsertOp::getPositionAttrName()}); - p << " : " << op.getSourceType() << " into " << op.getDestVectorType(); -} - -static ParseResult parseInsertOp(OpAsmParser &parser, OperationState &result) { - SmallVector attrs; - OpAsmParser::OperandType source, dest; - Type sourceType; - VectorType destType; - Attribute attr; - return failure(parser.parseOperand(source) || parser.parseComma() || - parser.parseOperand(dest) || - parser.parseAttribute(attr, InsertOp::getPositionAttrName(), - result.attributes) || - parser.parseOptionalAttrDict(attrs) || - parser.parseColonType(sourceType) || - parser.parseKeywordType("into", destType) || - parser.resolveOperand(source, sourceType, result.operands) || - parser.resolveOperand(dest, destType, result.operands) || - parser.addTypeToList(destType, result.types)); -} - static LogicalResult verify(InsertOp op) { auto positionAttr = op.position().getValue(); if (positionAttr.empty()) @@ -793,27 +768,6 @@ result.addAttribute(getStridesAttrName(), stridesAttr); } -static void print(OpAsmPrinter &p, InsertStridedSliceOp op) { - p << op.getOperationName() << " " << op.source() << ", " << op.dest() << " "; - p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.getSourceVectorType() << " into " << op.getDestVectorType(); -} - -static ParseResult parseInsertStridedSliceOp(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::OperandType source, dest; - VectorType sourceVectorType, destVectorType; - return failure( - parser.parseOperand(source) || parser.parseComma() || - parser.parseOperand(dest) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(sourceVectorType) || - parser.parseKeywordType("into", destVectorType) || - parser.resolveOperand(source, sourceVectorType, result.operands) || - parser.resolveOperand(dest, destVectorType, result.operands) || - parser.addTypeToList(destVectorType, result.types)); -} - // TODO(ntv) Should be moved to Tablegen Confined attributes. template static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, diff --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir --- a/mlir/test/Dialect/SPIRV/ops.mlir +++ b/mlir/test/Dialect/SPIRV/ops.mlir @@ -227,7 +227,7 @@ // ----- func @bit_field_insert_invalid_insert_type(%base: vector<3xi32>, %insert: vector<2xi32>, %offset: i32, %count: i16) -> vector<3xi32> { - // expected-error @+1 {{expected the same type for the base operand, insert operand, and result, but provided 'vector<3xi32>', 'vector<2xi32>' and 'vector<3xi32>'}} + // expected-error @+1 {{all of {base, insert, result} have same type}} %0 = "spv.BitFieldInsert" (%base, %insert, %offset, %count) : (vector<3xi32>, vector<2xi32>, i32, i16) -> vector<3xi32> spv.ReturnValue %0 : vector<3xi32> } @@ -856,7 +856,7 @@ func @select_op(%arg0: i1) -> () { %0 = spv.constant 2 : i32 %1 = spv.constant 3 : i32 - // expected-error @+1 {{need exactly two trailing types for select condition and object}} + // expected-error @+2 {{expected ','}} %2 = spv.Select %arg0, %0, %1 : i1 return } @@ -886,7 +886,7 @@ func @select_op(%arg1: vector<4xi1>) -> () { %0 = spv.constant dense<[2.0, 3.0, 4.0]> : vector<3xf32> %1 = spv.constant dense<[5, 6, 7]> : vector<3xi32> - // expected-error @+1 {{op result type and true value type must be the same}} + // expected-error @+1 {{all of {true_value, false_value, result} have same type}} %2 = "spv.Select"(%arg1, %0, %1) : (vector<4xi1>, vector<3xf32>, vector<3xi32>) -> vector<3xi32> return } @@ -896,7 +896,7 @@ func @select_op(%arg1: vector<4xi1>) -> () { %0 = spv.constant dense<[2.0, 3.0, 4.0]> : vector<3xf32> %1 = spv.constant dense<[5, 6, 7]> : vector<3xi32> - // expected-error @+1 {{op result type and false value type must be the same}} + // expected-error @+1 {{all of {true_value, false_value, result} have same type}} %2 = "spv.Select"(%arg1, %1, %0) : (vector<4xi1>, vector<3xi32>, vector<3xf32>) -> vector<3xi32> return } diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td --- a/mlir/test/mlir-tblgen/op-format-spec.td +++ b/mlir/test/mlir-tblgen/op-format-spec.td @@ -7,7 +7,8 @@ def TestDialect : Dialect { let name = "test"; } -class TestFormat_Op : Op { +class TestFormat_Op traits = []> + : Op { let assemblyFormat = fmt; } @@ -234,3 +235,24 @@ operands functional-type(operands, results) attr-dict }]>, Arguments<(ins AnyMemRef:$operand)>, Results<(outs AnyMemRef:$result)>; +// Check that we can infer type equalities from certain traits. +def ZCoverageValidD : TestFormat_Op<"variable_valid_d", [{ + operands type($result) attr-dict +}], [SameOperandsAndResultType]>, Arguments<(ins AnyMemRef:$operand)>, + Results<(outs AnyMemRef:$result)>; +def ZCoverageValidE : TestFormat_Op<"variable_valid_e", [{ + $operand type($operand) attr-dict +}], [SameOperandsAndResultType]>, Arguments<(ins AnyMemRef:$operand)>, + Results<(outs AnyMemRef:$result)>; +def ZCoverageValidF : TestFormat_Op<"variable_valid_f", [{ + operands type($other) attr-dict +}], [SameTypeOperands]>, Arguments<(ins AnyMemRef:$operand, AnyMemRef:$other)>; +def ZCoverageValidG : TestFormat_Op<"variable_valid_g", [{ + operands type($other) attr-dict +}], [AllTypesMatch<["operand", "other"]>]>, + Arguments<(ins AnyMemRef:$operand, AnyMemRef:$other)>; +def ZCoverageValidH : TestFormat_Op<"variable_valid_h", [{ + operands type($result) attr-dict +}], [AllTypesMatch<["operand", "result"]>]>, + Arguments<(ins AnyMemRef:$operand)>, Results<(outs AnyMemRef:$result)>; + diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -202,10 +202,32 @@ namespace { struct OperationFormat { + /// This class represents a specific resolver for an operand or result type. + class TypeResolution { + public: + TypeResolution() = default; + + /// Get the index into the buildable types for this type, or None. + Optional getBuilderIdx() const { return builderIdx; } + void setBuilderIdx(int idx) { builderIdx = idx; } + + /// Get the variable this type is resolved to, or None. + Optional getVariable() const { return variableName; } + void setVariable(StringRef variable) { variableName = variable; } + + private: + /// If the type is resolved with a buildable type, this is the index into + /// 'buildableTypes' in the parent format. + Optional builderIdx; + /// If the type is resolved based upon another operand or result, this is + /// the name of the variable that this type is resolved to. + Optional variableName; + }; + OperationFormat(const Operator &op) : allOperandTypes(false), allResultTypes(false) { - buildableOperandTypes.resize(op.getNumOperands(), llvm::None); - buildableResultTypes.resize(op.getNumResults(), llvm::None); + operandTypes.resize(op.getNumOperands(), TypeResolution()); + resultTypes.resize(op.getNumResults(), TypeResolution()); } /// Generate the operation parser from this format. @@ -228,7 +250,7 @@ llvm::MapVector> buildableTypes; /// The index of the buildable type, if valid, for every operand and result. - std::vector> buildableOperandTypes, buildableResultTypes; + std::vector operandTypes, resultTypes; }; } // end anonymous namespace @@ -398,8 +420,10 @@ } else { for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) { body << " result.addTypes("; - if (Optional val = buildableResultTypes[i]) + if (Optional val = resultTypes[i].getBuilderIdx()) body << "odsBuildableType" << *val; + else if (Optional var = resultTypes[i].getVariable()) + body << *var << "Types"; else body << op.getResultName(i) << "Types"; body << ");\n"; @@ -450,8 +474,10 @@ if (op.getNumOperands() > 1) { body << "llvm::concat("; interleaveComma(llvm::seq(0, op.getNumOperands()), body, [&](int i) { - if (Optional val = buildableOperandTypes[i]) + if (Optional val = operandTypes[i].getBuilderIdx()) body << "ArrayRef(odsBuildableType" << *val << ")"; + else if (Optional var = operandTypes[i].getVariable()) + body << *var << "Types"; else body << op.getOperand(i).name << "Types"; }); @@ -470,8 +496,10 @@ for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) { NamedTypeConstraint &operand = op.getOperand(i); body << " if (parser.resolveOperands(" << operand.name << "Operands, "; - if (Optional val = buildableOperandTypes[i]) + if (Optional val = operandTypes[i].getBuilderIdx()) body << "odsBuildableType" << *val << ", "; + else if (Optional var = operandTypes[i].getVariable()) + body << *var << "Types, " << operand.name << "OperandsLoc, "; else body << operand.name << "Types, " << operand.name << "OperandsLoc, "; body << "result.operands))\n return failure();\n"; @@ -803,6 +831,13 @@ // FormatParser //===----------------------------------------------------------------------===// +/// Function to find an element within the given range that has the same name as +/// 'name'. +template static auto findArg(RangeT &&range, StringRef name) { + auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; }); + return it != range.end() ? &*it : nullptr; +} + namespace { /// This class implements a parser for an instance of an operation assembly /// format. @@ -817,6 +852,18 @@ LogicalResult parse(); private: + /// Given the values of an `AllTypesMatch` trait, check for inferrable type + /// resolution. + void handleAllTypesMatchConstraint( + ArrayRef values, + llvm::StringMap &variableTyResolver); + /// Check for inferrable type resolution given all operands, and or results, + /// have the same type. If 'includeResults' is true, the results also have the + /// same type as all of the operands. + void handleSameTypesConstraint( + llvm::StringMap &variableTyResolver, + bool includeResults); + /// Parse a specific element. LogicalResult parseElement(std::unique_ptr &element, bool isTopLevel); @@ -870,8 +917,8 @@ OperationFormat &fmt; Operator &op; - // The following are various bits of format state used for verification during - // parsing. + // The following are various bits of format state used for verification + // during parsing. bool hasAllOperands = false, hasAttrDict = false; llvm::SmallBitVector seenOperandTypes, seenResultTypes; llvm::DenseSet seenOperands; @@ -894,6 +941,19 @@ if (!hasAttrDict) return emitError(loc, "format missing 'attr-dict' directive"); + // Check for any type traits that we can use for inferring types. + llvm::StringMap variableTyResolver; + for (const OpTrait &trait : op.getTraits()) { + const llvm::Record &def = trait.getDef(); + if (def.isSubClassOf("AllTypesMatch")) + handleAllTypesMatchConstraint(def.getValueAsListOfStrings("values"), + variableTyResolver); + else if (def.getName() == "SameTypeOperands") + handleSameTypesConstraint(variableTyResolver, /*includeResults=*/false); + else if (def.getName() == "SameOperandsAndResultType") + handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true); + } + // Check that all of the result types can be inferred. auto &buildableTypes = fmt.buildableTypes; if (!fmt.allResultTypes) { @@ -901,6 +961,13 @@ if (seenResultTypes.test(i)) continue; + // Check to see if we can infer this type from another variable. + auto varResolverIt = variableTyResolver.find(op.getResultName(i)); + if (varResolverIt != variableTyResolver.end()) { + fmt.resultTypes[i].setVariable(varResolverIt->second->name); + continue; + } + // If the result is not variadic, allow for the case where the type has a // builder that we can use. NamedTypeConstraint &result = op.getResult(i); @@ -911,7 +978,7 @@ } // Note in the format that this result uses the custom builder. auto it = buildableTypes.insert({*builder, buildableTypes.size()}); - fmt.buildableResultTypes[i] = it.first->second; + fmt.resultTypes[i].setBuilderIdx(it.first->second); } } @@ -927,21 +994,78 @@ } // Check that the operand type is in the format, or that it can be inferred. - if (!fmt.allOperandTypes && !seenOperandTypes.test(i)) { - // Similarly to results, allow a custom builder for resolving the type if - // we aren't using the 'operands' directive. - Optional builder = operand.constraint.getBuilderCall(); - if (!builder || (hasAllOperands && operand.isVariadic())) { - return emitError(loc, "format missing instance of operand #" + - Twine(i) + "('" + operand.name + "') type"); - } - auto it = buildableTypes.insert({*builder, buildableTypes.size()}); - fmt.buildableOperandTypes[i] = it.first->second; + if (fmt.allOperandTypes || seenOperandTypes.test(i)) + continue; + + // Check to see if we can infer this type from another variable. + auto varResolverIt = variableTyResolver.find(op.getOperand(i).name); + if (varResolverIt != variableTyResolver.end()) { + fmt.operandTypes[i].setVariable(varResolverIt->second->name); + continue; + } + + // Similarly to results, allow a custom builder for resolving the type if + // we aren't using the 'operands' directive. + Optional builder = operand.constraint.getBuilderCall(); + if (!builder || (hasAllOperands && operand.isVariadic())) { + return emitError(loc, "format missing instance of operand #" + Twine(i) + + "('" + operand.name + "') type"); } + auto it = buildableTypes.insert({*builder, buildableTypes.size()}); + fmt.operandTypes[i].setBuilderIdx(it.first->second); } return success(); } +void FormatParser::handleAllTypesMatchConstraint( + ArrayRef values, + llvm::StringMap &variableTyResolver) { + for (unsigned i = 0, e = values.size(); i != e; ++i) { + // Check to see if this value matches a resolved operand or result type. + const NamedTypeConstraint *arg = nullptr; + if ((arg = findArg(op.getOperands(), values[i]))) { + if (!seenOperandTypes.test(arg - op.operand_begin())) + continue; + } else if ((arg = findArg(op.getResults(), values[i]))) { + if (!seenResultTypes.test(arg - op.result_begin())) + continue; + } else { + continue; + } + + // Mark this value as the type resolver for the other variables. + for (unsigned j = 0; j != i; ++j) + variableTyResolver[values[j]] = arg; + for (unsigned j = i + 1; j != e; ++j) + variableTyResolver[values[j]] = arg; + } +} + +void FormatParser::handleSameTypesConstraint( + llvm::StringMap &variableTyResolver, + bool includeResults) { + const NamedTypeConstraint *resolver = nullptr; + int resolvedIt = -1; + + // Check to see if there is an operand or result to use for the resolution. + if ((resolvedIt = seenOperandTypes.find_first()) != -1) + resolver = &op.getOperand(resolvedIt); + else if (includeResults && (resolvedIt = seenResultTypes.find_first()) != -1) + resolver = &op.getResult(resolvedIt); + else + return; + + // Set the resolvers for each operand and result. + for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) + if (!seenOperandTypes.test(i) && !op.getOperand(i).name.empty()) + variableTyResolver[op.getOperand(i).name] = resolver; + if (includeResults) { + for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) + if (!seenResultTypes.test(i) && !op.getResultName(i).empty()) + variableTyResolver[op.getResultName(i)] = resolver; + } +} + LogicalResult FormatParser::parseElement(std::unique_ptr &element, bool isTopLevel) { // Directives. @@ -965,23 +1089,16 @@ StringRef name = varTok.getSpelling().drop_front(); llvm::SMLoc loc = varTok.getLoc(); - // Functor used to find an element within the given range that has the same - // name as 'name'. - auto findArg = [&](auto &&range) { - auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; }); - return it != range.end() ? &*it : nullptr; - }; - // Check that the parsed argument is something actually registered on the op. /// Attributes - if (const NamedAttribute *attr = findArg(op.getAttributes())) { + if (const NamedAttribute *attr = findArg(op.getAttributes(), name)) { if (isTopLevel && !seenAttrs.insert(attr).second) return emitError(loc, "attribute '" + name + "' is already bound"); element = std::make_unique(attr); return success(); } /// Operands - if (const NamedTypeConstraint *operand = findArg(op.getOperands())) { + if (const NamedTypeConstraint *operand = findArg(op.getOperands(), name)) { if (isTopLevel) { if (hasAllOperands || !seenOperands.insert(operand).second) return emitError(loc, "operand '" + name + "' is already bound"); @@ -990,7 +1107,7 @@ return success(); } /// Results. - if (const NamedTypeConstraint *result = findArg(op.getResults())) { + if (const auto *result = findArg(op.getResults(), name)) { if (isTopLevel) return emitError(loc, "results can not be used at the top level"); element = std::make_unique(result);