diff --git a/mlir/include/mlir/Query/Matcher/ErrorBuilder.h b/mlir/include/mlir/Query/Matcher/ErrorBuilder.h
--- a/mlir/include/mlir/Query/Matcher/ErrorBuilder.h
+++ b/mlir/include/mlir/Query/Matcher/ErrorBuilder.h
@@ -37,8 +37,12 @@
   None,
 
   // Parser Errors
+  ParserChainedExprInvalidArg,
+  ParserChainedExprNoCloseParen,
+  ParserChainedExprNoOpenParen,
   ParserFailedToBuildMatcher,
   ParserInvalidToken,
+  ParserMalformedChainedExpr,
   ParserNoCloseParen,
   ParserNoCode,
   ParserNoComma,
diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -63,8 +63,18 @@
 
   bool match(Operation *op) const { return implementation->match(op); }
 
+  // Records the name of the function to extract the matched operations into.
+  void setFunctionName(StringRef name) { functionName = name.str(); }
+
+  // Returns true if a non-empty function name has been recorded.
+  bool hasFunctionName() const { return !functionName.empty(); }
+
+  // Returns the recorded function name; empty if none was set.
+  StringRef getFunctionName() const { return functionName; }
+
 private:
   llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
+  std::string functionName;
 };
 
 } // namespace mlir::query::matcher
diff --git a/mlir/lib/Query/Matcher/Diagnostics.cpp b/mlir/lib/Query/Matcher/Diagnostics.cpp
--- a/mlir/lib/Query/Matcher/Diagnostics.cpp
+++ b/mlir/lib/Query/Matcher/Diagnostics.cpp
@@ -57,6 +57,14 @@
     return "Unexpected end of code.";
   case ErrorType::ParserOverloadedType:
     return "Input value has unresolved overloaded type: $0";
+  case ErrorType::ParserMalformedChainedExpr:
+    return "Period not followed by valid chained call.";
+  case ErrorType::ParserChainedExprInvalidArg:
+    return "Missing/Invalid argument for the chained call.";
+  case ErrorType::ParserChainedExprNoCloseParen:
+    return "Missing ')' for the chained call.";
+  case ErrorType::ParserChainedExprNoOpenParen:
+    return "Missing '(' for the chained call.";
   case ErrorType::ParserFailedToBuildMatcher:
     return "Failed to build matcher: $0.";
 
diff --git a/mlir/lib/Query/Matcher/Parser.h
b/mlir/lib/Query/Matcher/Parser.h
--- a/mlir/lib/Query/Matcher/Parser.h
+++ b/mlir/lib/Query/Matcher/Parser.h
@@ -64,10 +64,9 @@
 
     // Process a matcher expression. The caller takes ownership of the Matcher
     // object returned.
-    virtual VariantMatcher
-    actOnMatcherExpression(MatcherCtor ctor, SourceRange nameRange,
-                           llvm::ArrayRef<ParserValue> args,
-                           Diagnostics *error) = 0;
+    virtual VariantMatcher actOnMatcherExpression(
+        MatcherCtor ctor, SourceRange nameRange, llvm::StringRef functionName,
+        llvm::ArrayRef<ParserValue> args, Diagnostics *error) = 0;
 
     // Look up a matcher by name in the matcher name found by the parser.
     virtual std::optional<MatcherCtor>
@@ -93,10 +92,11 @@
     std::optional<MatcherCtor>
     lookupMatcherCtor(llvm::StringRef matcherName) override;
 
     VariantMatcher actOnMatcherExpression(MatcherCtor ctor,
                                           SourceRange nameRange,
+                                          llvm::StringRef functionName,
                                           llvm::ArrayRef<ParserValue> args,
                                           Diagnostics *error) override;
 
     std::vector<ArgKind> getAcceptedCompletionTypes(
         llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) override;
@@ -153,6 +153,11 @@
   Parser(CodeTokenizer *tokenizer, const Registry &matcherRegistry,
          const NamedValueMap *namedValues, Diagnostics *error);
 
+  // Parses the parenthesized string argument of a chained call such as
+  // `.extract("name")`. On success stores the string in `argument` and
+  // returns true; otherwise reports an error and returns false.
+  bool parseChainedExpression(std::string &argument);
+
   bool parseExpressionImpl(VariantValue *value);
 
   bool parseMatcherArgs(std::vector<ParserValue> &args, MatcherCtor ctor,
diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp
--- a/mlir/lib/Query/Matcher/Parser.cpp
+++ b/mlir/lib/Query/Matcher/Parser.cpp
@@ -26,12 +26,17 @@
     text = newText;
   }
 
+  // Known identifiers.
+  static const char *const ID_Extract;
+
   llvm::StringRef text;
   TokenKind kind = TokenKind::Eof;
   SourceRange range;
   VariantValue value;
 };
 
+const char *const Parser::TokenInfo::ID_Extract = "extract";
+
 class Parser::CodeTokenizer {
 public:
   // Constructor with matcherCode and error
@@ -301,6 +306,37 @@
   return parseMatcherExpressionImpl(nameToken, openToken, ctor, value);
 }
 
+// Parses the `("name")` part of a chained call such as `.extract("name")`.
+// The three tokens are consumed up front and then validated in order, so the
+// first malformed one is the one reported.
+bool Parser::parseChainedExpression(std::string &argument) {
+  // Parse the parenthesized argument to .extract("foo")
+  const TokenInfo openToken = tokenizer->consumeNextToken();
+  const TokenInfo argumentToken = tokenizer->consumeNextTokenIgnoreNewlines();
+  const TokenInfo closeToken = tokenizer->consumeNextTokenIgnoreNewlines();
+
+  if (openToken.kind != TokenKind::OpenParen) {
+    error->addError(openToken.range, ErrorType::ParserChainedExprNoOpenParen);
+    return false;
+  }
+
+  if (argumentToken.kind != TokenKind::Literal ||
+      !argumentToken.value.isString()) {
+    error->addError(argumentToken.range,
+                    ErrorType::ParserChainedExprInvalidArg);
+    return false;
+  }
+
+  if (closeToken.kind != TokenKind::CloseParen) {
+    error->addError(closeToken.range, ErrorType::ParserChainedExprNoCloseParen);
+    return false;
+  }
+
+  // If all checks passed, extract the argument and return true.
+  argument = argumentToken.value.getString();
+  return true;
+}
+
 // Parse the arguments of a matcher
 bool Parser::parseMatcherArgs(std::vector<ParserValue> &args, MatcherCtor ctor,
                               const TokenInfo &nameToken, TokenInfo &endToken) {
@@ -367,13 +403,35 @@
     return false;
   }
 
+  // Optional chained call on the matcher; only `.extract("name")` is known.
+  std::string functionName;
+  if (tokenizer->peekNextToken().kind == TokenKind::Period) {
+    tokenizer->consumeNextToken();
+    TokenInfo chainCallToken = tokenizer->consumeNextToken();
+    if (chainCallToken.kind == TokenKind::CodeCompletion) {
+      addCompletion(chainCallToken, MatcherCompletion("extract(\"", "extract"));
+      return false;
+    }
+
+    if (chainCallToken.kind != TokenKind::Ident ||
+        chainCallToken.text != TokenInfo::ID_Extract) {
+      error->addError(chainCallToken.range,
+                      ErrorType::ParserMalformedChainedExpr);
+      return false;
+    }
+
+    // The identifier is known to be `extract` here; parse its argument.
+    if (!parseChainedExpression(functionName))
+      return false;
+  }
+
   if (!ctor)
     return false;
   // Merge the start and end infos.
   SourceRange matcherRange = nameToken.range;
   matcherRange.end = endToken.range.end;
-  VariantMatcher result =
-      sema->actOnMatcherExpression(*ctor, matcherRange, args, error);
+  VariantMatcher result = sema->actOnMatcherExpression(
+      *ctor, matcherRange, functionName, args, error);
   if (result.isNull())
     return false;
   *value = result;
@@ -473,9 +531,10 @@
 }
 
 VariantMatcher Parser::RegistrySema::actOnMatcherExpression(
-    MatcherCtor ctor, SourceRange nameRange, llvm::ArrayRef<ParserValue> args,
-    Diagnostics *error) {
-  return RegistryManager::constructMatcher(ctor, nameRange, args, error);
+    MatcherCtor ctor, SourceRange nameRange, llvm::StringRef functionName,
+    llvm::ArrayRef<ParserValue> args, Diagnostics *error) {
+  return RegistryManager::constructMatcher(ctor, nameRange, functionName, args,
+                                           error);
 }
 
 std::vector<ArgKind> Parser::RegistrySema::getAcceptedCompletionTypes(
diff --git a/mlir/lib/Query/Matcher/RegistryManager.h b/mlir/lib/Query/Matcher/RegistryManager.h
--- a/mlir/lib/Query/Matcher/RegistryManager.h
+++
b/mlir/lib/Query/Matcher/RegistryManager.h
@@ -61,6 +61,7 @@
 
   static VariantMatcher constructMatcher(MatcherCtor ctor,
                                          internal::SourceRange nameRange,
+                                         llvm::StringRef functionName,
                                          ArrayRef<ParserValue> args,
                                          internal::Diagnostics *error);
 };
diff --git a/mlir/lib/Query/Matcher/RegistryManager.cpp b/mlir/lib/Query/Matcher/RegistryManager.cpp
--- a/mlir/lib/Query/Matcher/RegistryManager.cpp
+++ b/mlir/lib/Query/Matcher/RegistryManager.cpp
@@ -132,8 +132,21 @@
 
 VariantMatcher RegistryManager::constructMatcher(
     MatcherCtor ctor, internal::SourceRange nameRange,
-    llvm::ArrayRef<ParserValue> args, internal::Diagnostics *error) {
-  return ctor->create(nameRange, args, error);
+    llvm::StringRef functionName, llvm::ArrayRef<ParserValue> args,
+    internal::Diagnostics *error) {
+  VariantMatcher out = ctor->create(nameRange, args, error);
+  // Nothing to record when no function name was requested or when the
+  // matcher could not be built.
+  if (functionName.empty() || out.isNull())
+    return out;
+
+  // Attach the function name to the underlying matcher. If no DynMatcher can
+  // be obtained from the variant, it is returned unchanged.
+  if (std::optional<DynMatcher> result = out.getDynMatcher()) {
+    result->setFunctionName(functionName);
+    return VariantMatcher::SingleMatcher(*result);
+  }
+  return out;
 }
 
 } // namespace mlir::query::matcher
diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
--- a/mlir/lib/Query/Query.cpp
+++ b/mlir/lib/Query/Query.cpp
@@ -8,6 +8,8 @@
 
 #include "mlir/Query/Query.h"
 #include "QueryParser.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/IRMapping.h"
 #include "mlir/Query/Matcher/MatchFinder.h"
 #include "mlir/Query/QuerySession.h"
 #include "mlir/Support/LogicalResult.h"
@@ -34,6 +36,74 @@
                                      "\"" + binding + "\" binds here");
 }
 
+// Builds a `func.func` named `functionName` that holds clones of the
+// operations in `ops` (`func.return` ops are skipped). Every operand of an
+// operation in `ops` first becomes a function argument; arguments left
+// unused after cloning are erased. All results of the clones are returned.
+// The function is not inserted into any block, so the caller must erase it.
+static Operation *extractFunction(std::vector<Operation *> &ops,
+                                  MLIRContext *context,
+                                  llvm::StringRef functionName) {
+  context->loadDialect<func::FuncDialect>();
+  OpBuilder builder(context);
+
+  // Collect data for function creation
+  std::vector<Operation *> slice;
+  std::vector<Value> values;
+  std::vector<Type> outputTypes;
+
+  for (Operation *op : ops) {
+    if (!isa<func::ReturnOp>(op))
+      slice.push_back(op);
+
+    outputTypes.insert(outputTypes.end(), op->getResults().getTypes().begin(),
+                       op->getResults().getTypes().end());
+
+    values.insert(values.end(), op->getOperands().begin(),
+                  op->getOperands().end());
+  }
+
+  auto loc = builder.getUnknownLoc();
+
+  // Create the function
+  FunctionType funcType =
+      builder.getFunctionType(ValueRange(values), outputTypes);
+  func::FuncOp funcOp = func::FuncOp::create(loc, functionName, funcType);
+
+  // The builder now appends to the freshly added entry block.
+  builder.setInsertionPointToEnd(funcOp.addEntryBlock());
+
+  // Map original values to function arguments
+  IRMapping mapper;
+  for (const auto &arg : llvm::enumerate(values))
+    mapper.map(arg.value(), funcOp.getArgument(arg.index()));
+
+  // Clone operations and build function body
+  std::vector<Operation *> clonedOps;
+  for (Operation *slicedOp : slice)
+    clonedOps.push_back(builder.clone(*slicedOp, mapper));
+
+  // Remove unused function arguments
+  size_t currentIndex = 0;
+  while (currentIndex < funcOp.getNumArguments()) {
+    if (funcOp.getArgument(currentIndex).use_empty())
+      funcOp.eraseArgument(currentIndex);
+    else
+      ++currentIndex;
+  }
+
+  // Collect cloned values
+  std::vector<Value> clonedVals;
+  for (Operation *clonedOp : clonedOps)
+    clonedVals.insert(clonedVals.end(), clonedOp->result_begin(),
+                      clonedOp->result_end());
+
+  // Add return operation
+  builder.create<func::ReturnOp>(loc, clonedVals);
+
+  return funcOp;
+}
+
 Query::~Query() = default;
 
 mlir::LogicalResult InvalidQuery::run(llvm::raw_ostream &os,
@@ -65,9 +135,24 @@
 
 mlir::LogicalResult MatchQuery::run(llvm::raw_ostream &os,
                                     QuerySession &qs) const {
+  Operation *rootOp = qs.getRootOp();
   int matchCount = 0;
   std::vector<Operation *> matches =
-      matcher::MatchFinder().getMatches(qs.getRootOp(), matcher);
+      matcher::MatchFinder().getMatches(rootOp, matcher);
+
+  // A matcher carrying a function name came from `.extract("name")`: print
+  // the matches as a standalone function instead of listing them.
+  if (matcher.hasFunctionName()) {
+    llvm::StringRef functionName = matcher.getFunctionName();
+    Operation *function =
+        extractFunction(matches, rootOp->getContext(), functionName);
+    os << "\n" << *function << "\n\n";
+    // The extracted function is detached from the IR; erase it once printed
+    // so it is not leaked.
+    function->erase();
+    return mlir::success();
+  }
+
   os << "\n";
   for (Operation *op : matches) {
     os << "Match #" << ++matchCount << ":\n\n";
diff --git a/mlir/test/mlir-query/function-extraction.mlir b/mlir/test/mlir-query/function-extraction.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/mlir-query/function-extraction.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-query %s -c "m hasOpName(\"arith.mulf\").extract(\"testmul\")" | FileCheck %s
+
+// CHECK: func.func @testmul({{.*}}) -> (f32, f32, f32) {
+// CHECK: %[[MUL0:.*]] = arith.mulf {{.*}} : f32
+// CHECK: %[[MUL1:.*]] = arith.mulf {{.*}}, %[[MUL0]] : f32
+// CHECK: %[[MUL2:.*]] = arith.mulf {{.*}} : f32
+// CHECK-NEXT: return %[[MUL0]], %[[MUL1]], %[[MUL2]] : f32, f32, f32
+
+func.func @mixedOperations(%a: f32, %b: f32, %c: f32) -> f32 {
+  %sum0 = arith.addf %a, %b : f32
+  %sub0 = arith.subf %sum0, %c : f32
+  %mul0 = arith.mulf %a, %sub0 : f32
+  %sum1 = arith.addf %b, %c : f32
+  %mul1 = arith.mulf %sum1, %mul0 : f32
+  %sub2 = arith.subf %mul1, %a : f32
+  %sum2 = arith.addf %mul1, %b : f32
+  %mul2 = arith.mulf %sub2, %sum2 : f32
+  return %mul2 : f32
+}