diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -907,6 +907,13 @@ virtual Operation *parseGenericOperation(Block *insertBlock, Block::iterator insertPt) = 0; + /// Parse and consume an operation name in the custom form. Return a tuple + /// containing the name of the operation and pointers, possibly nullptr, to + /// the corresponding dialect & abstract-operation. + using ParseOpNameResult = + std::tuple; + virtual ParseOpNameResult parseCustomOperationName() = 0; + //===--------------------------------------------------------------------===// // Operand Parsing //===--------------------------------------------------------------------===// @@ -918,6 +925,12 @@ unsigned number; // Number, e.g. 12 for an operand like %xyz#12 }; + /// Parse the remainder of a generic-form operation (successors, regions, + /// attributes, `:` type) after its operands, and resolve those operands. + virtual ParseResult + parseGenericOperationAfterOperands(OperationState &result, + ArrayRef operandInfos) = 0; + /// Parse a single operand. virtual ParseResult parseOperand(OperandType &result) = 0; diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -310,6 +310,11 @@ /// Parse an operation instance that is in the generic form. Operation *parseGenericOperation(); + /// Parse the remainder of a generic-form operation (successors, regions, + /// attributes, `:` type) after its operands, and resolve those operands. + ParseResult parseGenericOperationAfterOperands(OperationState &result, + ArrayRef useInfo); + /// Parse an operation instance that is in the generic form and insert it at /// the provided insertion point. Operation *parseGenericOperation(Block *insertBlock, @@ -335,6 +340,13 @@ /// resultInfo specifies information about the "%name =" specifiers.
Operation *parseCustomOperation(ArrayRef resultIDs); + /// Parse and consume an operation name in the custom form. Return a tuple + /// containing the name of the operation and pointers, possibly nullptr, to + /// the corresponding dialect & abstract-operation. + using ParseOpNameResult = + std::tuple; + ParseOpNameResult parseCustomOperationName(); + //===--------------------------------------------------------------------===// // Region Parsing //===--------------------------------------------------------------------===// @@ -972,6 +984,70 @@ }; } // namespace +ParseResult OperationParser::parseGenericOperationAfterOperands( + OperationState &result, ArrayRef useInfo) { + // Parse the successor list. + if (getToken().is(Token::l_square)) { + // Check if the operation is a known terminator. + const AbstractOperation *abstractOp = result.name.getAbstractOperation(); + if (abstractOp && !abstractOp->hasTrait()) + return emitError("successors in non-terminator"); + + SmallVector successors; + if (parseSuccessors(successors)) + return failure(); + result.addSuccessors(successors); + } + + // Parse the region list. + if (consumeIf(Token::l_paren)) { + do { + // Create temporary regions with the top level region as parent.
+ result.regions.emplace_back(new Region(topLevelOp)); + if (parseRegion(*result.regions.back(), /*entryArguments=*/{})) + return failure(); + } while (consumeIf(Token::comma)); + if (parseToken(Token::r_paren, "expected ')' to end region list")) + return failure(); + } + + if (getToken().is(Token::l_brace)) { + if (parseAttributeDict(result.attributes)) + return failure(); + } + + if (parseToken(Token::colon, "expected ':' followed by operation type")) + return failure(); + + auto typeLoc = getToken().getLoc(); + auto type = parseType(); + if (!type) + return failure(); + auto fnType = type.dyn_cast(); + if (!fnType) + return emitError(typeLoc, "expected function type"); + + result.addTypes(fnType.getResults()); + + // Check that we have the right number of types for the operands. + auto operandTypes = fnType.getInputs(); + if (operandTypes.size() != useInfo.size()) { + auto plural = "s"[useInfo.size() == 1]; + return emitError(typeLoc, "expected ") + << useInfo.size() << " operand type" << plural << " but had " + << operandTypes.size(); + } + + // Resolve all of the operands. + for (unsigned i = 0, e = useInfo.size(); i != e; ++i) { + result.operands.push_back(resolveSSAUse(useInfo[i], operandTypes[i])); + if (!result.operands.back()) + return failure(); + } + + return success(); +} + Operation *OperationParser::parseGenericOperation() { // Get location information for the operation. auto srcLocation = getEncodedSourceLocation(getToken().getLoc()); @@ -985,6 +1061,7 @@ consumeToken(Token::string); OperationState result(srcLocation, name); + CleanupOpStateRegions guard{result}; // Lazy load dialects in the context as needed. if (!result.name.getAbstractOperation()) { @@ -1016,67 +1093,9 @@ return nullptr; } - // Parse the successor list. - if (getToken().is(Token::l_square)) { - // Check if the operation is a known terminator.
- const AbstractOperation *abstractOp = result.name.getAbstractOperation(); - if (abstractOp && !abstractOp->hasTrait()) - return emitError("successors in non-terminator"), nullptr; - - SmallVector successors; - if (parseSuccessors(successors)) - return nullptr; - result.addSuccessors(successors); - } - - // Parse the region list. - CleanupOpStateRegions guard{result}; - if (consumeIf(Token::l_paren)) { - do { - // Create temporary regions with the top level region as parent. - result.regions.emplace_back(new Region(topLevelOp)); - if (parseRegion(*result.regions.back(), /*entryArguments=*/{})) - return nullptr; - } while (consumeIf(Token::comma)); - if (parseToken(Token::r_paren, "expected ')' to end region list")) - return nullptr; - } - - if (getToken().is(Token::l_brace)) { - if (parseAttributeDict(result.attributes)) - return nullptr; - } - - if (parseToken(Token::colon, "expected ':' followed by operation type")) + if (parseGenericOperationAfterOperands(result, operandInfos)) return nullptr; - auto typeLoc = getToken().getLoc(); - auto type = parseType(); - if (!type) - return nullptr; - auto fnType = type.dyn_cast(); - if (!fnType) - return (emitError(typeLoc, "expected function type"), nullptr); - - result.addTypes(fnType.getResults()); - - // Check that we have the right number of types for the operands. - auto operandTypes = fnType.getInputs(); - if (operandTypes.size() != operandInfos.size()) { - auto plural = "s"[operandInfos.size() == 1]; - return (emitError(typeLoc, "expected ") - << operandInfos.size() << " operand type" << plural - << " but had " << operandTypes.size(), - nullptr); - } - - // Resolve all of the operands. - for (unsigned i = 0, e = operandInfos.size(); i != e; ++i) { - result.operands.push_back(resolveSSAUse(operandInfos[i], operandTypes[i])); - if (!result.operands.back()) - return nullptr; - } - // Create the operation and try to parse a location for it.
Operation *op = opBuilder.createOperation(result); if (parseTrailingLocationSpecifier(op)) @@ -1137,6 +1156,24 @@ return parser.parseGenericOperation(insertBlock, insertPt); } + ParseOpNameResult parseCustomOperationName() final { + return parser.parseCustomOperationName(); + } + + ParseResult + parseGenericOperationAfterOperands(OperationState &result, + ArrayRef operandInfos) final { + + SmallVector useInfo; + for (auto &operandInfo : operandInfos) + useInfo.push_back({ + operandInfo.name, + operandInfo.number, + operandInfo.location, + }); + + return parser.parseGenericOperationAfterOperands(result, useInfo); + } //===--------------------------------------------------------------------===// // Utilities //===--------------------------------------------------------------------===// @@ -1510,9 +1547,7 @@ }; } // end anonymous namespace. -Operation * -OperationParser::parseCustomOperation(ArrayRef resultIDs) { - llvm::SMLoc opLoc = getToken().getLoc(); +OperationParser::ParseOpNameResult OperationParser::parseCustomOperationName() { std::string opName = getTokenSpelling().str(); auto *opDefinition = AbstractOperation::lookup(opName, getContext()); StringRef defaultDialect = getState().defaultDialectStack.back(); @@ -1546,13 +1581,27 @@ } } + consumeToken(); + + return std::tuple(opName, dialect, opDefinition); +} + +Operation * +OperationParser::parseCustomOperation(ArrayRef resultIDs) { + llvm::SMLoc opLoc = getToken().getLoc(); + + auto parseOpNameResult = parseCustomOperationName(); + std::string opName = std::get<0>(parseOpNameResult); + Dialect *dialect = std::get<1>(parseOpNameResult); + const AbstractOperation *opDefinition = std::get<2>(parseOpNameResult); + // This is the actual hook for the custom op parsing, usually implemented by // the op itself (`Op::parse()`). We retrieve it either from the // AbstractOperation or from the Dialect.
function_ref parseAssemblyFn; bool isIsolatedFromAbove = false; - defaultDialect = ""; + StringRef defaultDialect = ""; if (opDefinition) { parseAssemblyFn = opDefinition->getParseAssemblyFn(); isIsolatedFromAbove = @@ -1574,8 +1623,6 @@ auto restoreDefaultDialect = llvm::make_scope_exit( [&]() { getState().defaultDialectStack.pop_back(); }); - consumeToken(); - // If the custom op parser crashes, produce some indication to help // debugging. llvm::PrettyStackTraceFormat fmt("MLIR Parser: custom op parser '%s'", diff --git a/mlir/test/IR/pretty_printed_region_op_op.mlir b/mlir/test/IR/pretty_printed_region_op_op.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/pretty_printed_region_op_op.mlir @@ -0,0 +1,35 @@ +// RUN: mlir-opt -allow-unregistered-dialect -split-input-file %s | FileCheck %s --check-prefixes=CHECK-CUSTOM,CHECK +// RUN: mlir-opt -allow-unregistered-dialect -mlir-print-op-generic -split-input-file %s | FileCheck %s --check-prefixes=CHECK,CHECK-GENERIC + +// ----- + +func @pretty_printed_region_op(%arg0 : f32, %arg1 : f32) -> (f32) { +// CHECK-CUSTOM: test.pretty_printed_region %arg1, %arg0 start special.op end : (f32, f32) -> f32 +// CHECK-GENERIC: "test.pretty_printed_region"(%arg1, %arg0) +// CHECK-GENERIC: ^bb0(%arg[[x:[0-9]+]]: f32, %arg[[y:[0-9]+]]: f32 +// CHECK-GENERIC: %[[RES:.*]] = "special.op"(%arg[[x]], %arg[[y]]) : (f32, f32) -> f32 +// CHECK-GENERIC: "test.return"(%[[RES]]) : (f32) -> () +// CHECK-GENERIC: : (f32, f32) -> f32 + + %res = test.pretty_printed_region %arg1, %arg0 start special.op end : (f32, f32) -> (f32) loc("some_NameLoc") + return %res : f32 +} + +// ----- + +func @pretty_printed_region_op(%arg0 : f32, %arg1 : f32) -> (f32) { +// CHECK-CUSTOM: test.pretty_printed_region %arg1, %arg0 +// CHECK-GENERIC: "test.pretty_printed_region"(%arg1, %arg0) +// CHECK: ^bb0(%arg[[x:[0-9]+]]: f32, %arg[[y:[0-9]+]]: f32): +// CHECK: %[[RES:.*]] = "non.special.op"(%arg[[x]], %arg[[y]]) : (f32, f32) -> f32 +// CHECK:
"test.return"(%[[RES]]) : (f32) -> () +// CHECK: : (f32, f32) -> f32 + + %0 = test.pretty_printed_region %arg1, %arg0 ( { + ^bb0(%arg2: f32, %arg3: f32): + %1 = "non.special.op"(%arg2, %arg3) : (f32, f32) -> f32 + "test.return"(%1) : (f32) -> () + }) : (f32, f32) -> f32 + return %0 : f32 +} + diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -720,6 +720,103 @@ p.printGenericOp(&op.getRegion().front().front()); } +//===----------------------------------------------------------------------===// +// Test PrettyPrintedRegionOp - exercising the following parser APIs +// parseGenericOperationAfterOperands +// parseCustomOperationName +//===----------------------------------------------------------------------===// + +static ParseResult parsePrettyPrintedRegionOp(OpAsmParser &parser, + OperationState &result) { + + llvm::SMLoc loc = parser.getCurrentLocation(); + Location currLocation = parser.getEncodedSourceLoc(loc); + + // Parse the operands. + SmallVector operands; + if (parser.parseOperandList(operands)) + return failure(); + + // Check if we are parsing the pretty-printed version: the operands followed + // by `start`, the inner op name, `end`, `:` and a function type. Otherwise + // fall back to parsing the generic form that follows the operands. + if (!succeeded(parser.parseOptionalKeyword("start"))) + return parser.parseGenericOperationAfterOperands(result, operands); + + auto parseOpNameResult = parser.parseCustomOperationName(); + StringRef innerOpName = std::get<0>(parseOpNameResult); + + FunctionType opFntype; + Optional explicitLoc; + if (parser.parseKeyword("end") || parser.parseColon() || + parser.parseType(opFntype) || + parser.parseOptionalLocationSpecifier(explicitLoc)) + return failure(); + + // Use the explicitly provided location if there is one; otherwise use the + // location recorded before the operands were parsed.
+ Location opLoc = explicitLoc.getValueOr(currLocation); + + // Resolve the parsed operands against the function type's input types. + if (parser.resolveOperands(operands, opFntype.getInputs(), loc, + result.operands)) + return failure(); + + // Add a region for op. + Region ®ion = *result.addRegion(); + + // Create a basic-block inside op's region. + Block &block = region.emplaceBlock(); + + // Create and insert an "inner-op" operation in the block. + // Test-only assumption: the inner op is binary, with operand and result + // types equal to the first input type (assumes opFntype has >= 1 input). + auto innerOpType = opFntype.getInput(0); + auto lhs = block.addArgument(innerOpType, opLoc); + auto rhs = block.addArgument(innerOpType, opLoc); + + OpBuilder builder(parser.getBuilder().getContext()); + builder.setInsertionPointToStart(&block); + + OperationState innerOpState(opLoc, innerOpName); + innerOpState.operands.push_back(lhs); + innerOpState.operands.push_back(rhs); + innerOpState.addTypes(innerOpType); + + Operation *innerOp = builder.createOperation(innerOpState); + + // Insert a return statement in the block returning the inner-op's result. + builder.create(innerOp->getLoc(), innerOp->getResults()); + + // Populate the op's operation state with its result types and location. + result.addTypes(opFntype.getResults()); + result.location = innerOp->getLoc(); + + return success(); +} + +static void print(OpAsmPrinter &p, PrettyPrintedRegionOp op) { + p << ' '; + p.printOperands(op.getOperands()); + + Operation &innerOp = op.getRegion().front().front(); + // Assuming that region has a single non-terminator inner-op, if the inner-op + // meets some criteria (which in this case is a simple one based on the name + // of inner-op), then we can print the entire region in a succinct way. + // Here we assume that the prototype of "special.op" can be trivially derived + // while parsing it back.
+ if (innerOp.getName().getStringRef().equals("special.op")) { + p << " start special.op end"; + } else { + p << " ("; + p.printRegion(op.getRegion()); + p << ")"; + } + + p << " : "; + p.printFunctionalType(op); +} + //===----------------------------------------------------------------------===// // Test PolyForOp - parse list of region arguments. //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1630,6 +1630,25 @@ let printer = [{ return ::print(p, *this); }]; } +def PrettyPrintedRegionOp : TEST_Op<"pretty_printed_region", + [SingleBlockImplicitTerminator<"TestReturnOp">]> { + let summary = "pretty_printed_region operation"; + let description = [{ + Test-op can be printed either in a "pretty" or "non-pretty" way based on + some criteria. The custom parser parses both versions while exercising the + APIs: parseCustomOperationName & parseGenericOperationAfterOperands. + }]; + let arguments = (ins + AnyType:$input1, + AnyType:$input2 + ); + + let results = (outs AnyType); + let regions = (region SizedRegion<1>:$region); + let parser = [{ return ::parse$cppClass(parser, result); }]; + let printer = [{ return ::print(p, *this); }]; +} + def PolyForOp : TEST_Op<"polyfor"> { let summary = "polyfor operation";