diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -83,6 +83,11 @@
   // DAG `tree` as an attribute.
   void emitAttributeMatch(DagNode tree, int argIndex, int depth, int indent);
 
+  // Emits C++ for checking a match with a corresponding failure diagnostic.
+  // `failureFmt` is plain message text; it is escaped and quoted when emitted.
+  void emitMatchCheck(int depth, const FmtObjectBase &matchFmt,
+                      const llvm::formatv_object_base &failureFmt);
+
   //===--------------------------------------------------------------------===//
   // Rewrite utilities
   //===--------------------------------------------------------------------===//
@@ -287,7 +292,8 @@
 
   // Only need to verify if the matcher's type is different from the one
   // of op definition.
-  if (operand->constraint != matcher.getAsConstraint()) {
+  Constraint constraint = matcher.getAsConstraint();
+  if (operand->constraint != constraint) {
     if (operand->isVariadic()) {
       auto error = formatv(
           "further constrain op {0}'s variadic operand #{1} unsupported now",
@@ -297,10 +303,13 @@
 
     auto self = formatv("(*castedOp{0}.getODSOperands({1}).begin()).getType()",
                         depth, argIndex);
-    os.indent(indent) << "if (!("
-                      << std::string(tgfmt(matcher.getConditionTemplate(),
-                                           &fmtCtx.withSelf(self)))
-                      << ")) return failure();\n";
+    emitMatchCheck(
+        depth,
+        tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
+        formatv("operand {0} of op '{1}' failed to satisfy constraint: "
+                "'{2}'",
+                operand - op.operand_begin(), op.getOperationName(),
+                constraint.getDescription()));
   }
 }
 
@@ -321,7 +330,6 @@
 
 void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth,
                                         int indent) {
-
   Operator &op = tree.getDialectOp(opMap);
   auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
   const auto &attr = namedAttr->attr;
@@ -344,7 +352,11 @@
     // should just capture a mlir::Attribute() to signal the missing state.
     // That is precisely what getAttr() returns on missing attributes.
   } else {
-    os.indent(indent) << "if (!tblgen_attr) return failure();\n";
+    emitMatchCheck(depth, tgfmt("tblgen_attr", &fmtCtx),
+                   formatv("expected op '{0}' to have attribute '{1}' "
+                           "of type '{2}'",
+                           op.getOperationName(), namedAttr->name,
+                           attr.getStorageType()));
   }
 
   auto matcher = tree.getArgAsLeaf(argIndex);
@@ -357,10 +369,13 @@
 
     // If a constraint is specified, we need to generate C++ statements to
     // check the constraint.
-    os.indent(indent) << "if (!("
-                      << std::string(tgfmt(matcher.getConditionTemplate(),
-                                           &fmtCtx.withSelf("tblgen_attr")))
-                      << ")) return failure();\n";
+    emitMatchCheck(
+        depth,
+        tgfmt(matcher.getConditionTemplate(), &fmtCtx.withSelf("tblgen_attr")),
+        formatv("op '{0}' attribute '{1}' failed to satisfy constraint: "
+                "{2}",
+                op.getOperationName(), namedAttr->name,
+                matcher.getAsConstraint().getDescription()));
   }
 
   // Capture the value
@@ -374,6 +389,31 @@
   os.indent(indent) << "}\n";
 }
 
+void PatternEmitter::emitMatchCheck(
+    int depth, const FmtObjectBase &matchFmt,
+    const llvm::formatv_object_base &failureFmt) {
+  // The failure message is embedded in a C++ string literal of the generated
+  // code, so escape any quotes, backslashes or non-printable characters that a
+  // constraint description may contain.
+  std::string failureMsg;
+  llvm::raw_string_ostream failureStream(failureMsg);
+  failureStream.write_escaped(failureFmt.str());
+
+  // {0} The match depth (used to get the operation that failed to match).
+  // {1} The C++ condition that must hold for the match to succeed.
+  // {2} The escaped failure message.
+  // Literal braces are doubled so that formatv does not treat them as
+  // replacement fields.
+  const char *matchStr = R"(
+    if (!({1})) {{
+      return rewriter.notifyMatchFailure(op{0}, [&](::mlir::Diagnostic &diag) {{
+        diag << "{2}";
+      });
+    })";
+  os << llvm::formatv(matchStr, depth, matchFmt.str(), failureStream.str())
+     << "\n";
+}
+
 void PatternEmitter::emitMatchLogic(DagNode tree) {
   LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
   emitOpMatch(tree, 0);
@@ -383,13 +423,13 @@
     auto &entities = appliedConstraint.entities;
 
     auto condition = constraint.getConditionTemplate();
-    auto cmd = "if (!({0})) return failure();\n";
-
     if (isa<TypeConstraint>(constraint)) {
       auto self = formatv("({0}.getType())",
                           symbolInfoMap.getValueAndRangeUse(entities.front()));
-      os.indent(4) << formatv(cmd,
-                              tgfmt(condition, &fmtCtx.withSelf(self.str())));
+      emitMatchCheck(
+          /*depth=*/0, tgfmt(condition, &fmtCtx.withSelf(self.str())),
+          formatv("value entity '{0}' failed to satisfy constraint: {1}",
+                  entities.front(), constraint.getDescription()));
     } else if (isa<AttrConstraint>(constraint)) {
       PrintFatalError(
           loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
@@ -408,9 +448,13 @@
       self = symbolInfoMap.getValueAndRangeUse(self);
       for (; i < 4; ++i)
         names.push_back("");
-      os.indent(4) << formatv(cmd,
-                              tgfmt(condition, &fmtCtx.withSelf(self), names[0],
-                                    names[1], names[2], names[3]));
+      emitMatchCheck(/*depth=*/0,
+                     tgfmt(condition, &fmtCtx.withSelf(self), names[0],
+                           names[1], names[2], names[3]),
+                     formatv("entities '{0}' failed to satisfy constraint: "
+                             "{1}",
+                             llvm::join(entities, ", "),
+                             constraint.getDescription()));
     }
   }
   LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");