diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
--- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
+++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
@@ -371,6 +371,29 @@
 }
 
 //===----------------------------------------------------------------------===//
+// pdl_interp::ContinueOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_ContinueOp
+    : PDLInterp_Op<"continue", [NoSideEffect, HasParent<"ForEachOp">,
+                                Terminator]> {
+  let summary = "Breaks the current iteration";
+  let description = [{
+    `pdl_interp.continue` operation breaks the current iteration within the
+    `pdl_interp.foreach` region and continues with the next iteration from
+    the beginning of the region.
+
+    Example:
+
+    ```mlir
+    pdl_interp.continue
+    ```
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
+//===----------------------------------------------------------------------===//
 // pdl_interp::CreateAttributeOp
 //===----------------------------------------------------------------------===//
 
@@ -534,6 +557,48 @@
 }
 
 //===----------------------------------------------------------------------===//
+// pdl_interp::ForEachOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_ForEachOp
+    : PDLInterp_Op<"foreach", [Terminator]> {
+  let summary = "Iterates over a range of values or ranges";
+  let description = [{
+    `pdl_interp.foreach` iteratively selects an element from a range of values
+    and executes the region until `pdl_interp.continue` is reached.
+
+    In the bytecode interpreter, this operation is implemented by looping over
+    the values and, for each selection, running the bytecode until we reach
+    `pdl_interp.continue`. This may result in multiple matches being reported.
+    Note that the input range is mutated (popped from).
+
+    Example:
+
+    ```mlir
+    pdl_interp.foreach %op : !pdl.operation in %ops {
+      pdl_interp.continue
+    } -> ^next
+    ```
+  }];
+
+  let arguments = (ins PDL_RangeOf<PDL_AnyType>:$values);
+  let regions = (region AnyRegion:$region);
+  let successors = (successor AnySuccessor:$successor);
+
+  let builders = [
+    OpBuilder<(ins "Value":$range, "Block *":$successor, "bool":$initLoop)>
+  ];
+
+  let extraClassDeclaration = [{
+    /// Returns the loop variable.
+    BlockArgument getLoopVariable() { return region().getArgument(0); }
+  }];
+  let parser = [{ return ::parseForEachOp(parser, result); }];
+  let printer = [{ return ::print(p, *this); }];
+  let verifier = [{ return ::verify(*this); }];
+}
+
+//===----------------------------------------------------------------------===//
 // pdl_interp::GetAttributeOp
 //===----------------------------------------------------------------------===//
 
@@ -751,6 +816,40 @@
 }
 
 //===----------------------------------------------------------------------===//
+// pdl_interp::GetUsersOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_GetUsersOp
+    : PDLInterp_Op<"get_users", [NoSideEffect]> {
+  let summary = "Get the users of a `Value`";
+  let description = [{
+    `pdl_interp.get_users` extracts the users that accept this value. In the
+    case of a range, the users of the first value are returned.
+
+    Example:
+
+    ```mlir
+    // Get all the users of a single value.
+    %ops = pdl_interp.get_users of %value : !pdl.value
+
+    // Get all the users of the first value in a range.
+    %ops = pdl_interp.get_users of %values : !pdl.range<value>
+    ```
+  }];
+
+  let arguments = (ins PDL_InstOrRangeOf<PDL_Value>:$value);
+  let results = (outs PDL_RangeOf<PDL_Operation>:$operations);
+  let assemblyFormat = "`of` $value `:` type($value) attr-dict";
+  let builders = [
+    OpBuilder<(ins "Value":$value), [{
+      build($_builder, $_state,
+            pdl::RangeType::get($_builder.getType<pdl::OperationType>()),
+            value);
+    }]>,
+  ];
+}
+
+//===----------------------------------------------------------------------===//
 // pdl_interp::GetValueTypeOp
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
--- a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
+++ b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
@@ -66,6 +66,93 @@
 }
 
 //===----------------------------------------------------------------------===//
+// pdl_interp::ForEachOp
+//===----------------------------------------------------------------------===//
+
+/// Builds a `pdl_interp.foreach` over `range` whose successor is `successor`.
+/// When `initLoop` is set, the body region also receives an entry block whose
+/// single argument (the loop variable) has the element type of `range`.
+void ForEachOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
+                      Value range, Block *successor, bool initLoop) {
+  build(builder, state, range, successor);
+  if (initLoop) {
+    // Create the block and the loop variable.
+    auto rangeType = range.getType().cast<pdl::RangeType>();
+    state.regions.front()->emplaceBlock();
+    state.regions.front()->addArgument(rangeType.getElementType());
+  }
+}
+
+/// Parses the custom assembly of `pdl_interp.foreach`:
+///   loop-var `:` type `in` range-operand region attr-dict `->` successor
+static ParseResult parseForEachOp(OpAsmParser &parser, OperationState &result) {
+  // Parse the loop variable followed by type.
+  OpAsmParser::OperandType loopVariable;
+  Type loopVariableType;
+  if (parser.parseRegionArgument(loopVariable) ||
+      parser.parseColonType(loopVariableType))
+    return failure();
+
+  // Parse the "in" keyword.
+  if (parser.parseKeyword("in", " after loop variable"))
+    return failure();
+
+  // Parse the operand (value range).
+  OpAsmParser::OperandType operandInfo;
+  if (parser.parseOperand(operandInfo))
+    return failure();
+
+  // Resolve the operand. Its type is not spelled in the assembly; it is
+  // derived as a range of the loop variable type.
+  Type rangeType = pdl::RangeType::get(loopVariableType);
+  if (parser.resolveOperand(operandInfo, rangeType, result.operands))
+    return failure();
+
+  // Parse the body region.
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, {loopVariable}, {loopVariableType}))
+    return failure();
+
+  // Parse the attribute dictionary.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  // Parse the successor.
+  Block *successor;
+  if (parser.parseArrow() || parser.parseSuccessor(successor))
+    return failure();
+  result.addSuccessors(successor);
+
+  return success();
+}
+
+/// Prints the custom assembly form accepted by `parseForEachOp`.
+static void print(OpAsmPrinter &p, ForEachOp op) {
+  BlockArgument arg = op.getLoopVariable();
+  p << ' ' << arg << " : " << arg.getType() << " in " << op.values();
+  p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
+  p.printOptionalAttrDict(op->getAttrs());
+  p << " -> ";
+  p.printSuccessor(op.successor());
+}
+
+/// Verifies that the body region has exactly one argument and that the operand
+/// type is a range of that argument's type.
+static LogicalResult verify(ForEachOp op) {
+  // Verify that the operation has exactly one argument.
+  if (op.region().getNumArguments() != 1)
+    return op.emitOpError("requires exactly one argument");
+
+  // Verify that the loop variable and the operand (value range)
+  // have compatible types.
+  BlockArgument arg = op.getLoopVariable();
+  Type rangeType = pdl::RangeType::get(arg.getType());
+  if (rangeType != op.values().getType())
+    return op.emitOpError("operand must be a range of loop variable type");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // pdl_interp::GetValueTypeOp
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/test/Dialect/PDLInterp/ops.mlir b/mlir/test/Dialect/PDLInterp/ops.mlir
--- a/mlir/test/Dialect/PDLInterp/ops.mlir
+++ b/mlir/test/Dialect/PDLInterp/ops.mlir
@@ -23,3 +23,28 @@
 
   pdl_interp.finalize
 }
+
+// -----
+
+func @foreach(%ops: !pdl.range<operation>) {
+  // iterate over a range of operations
+  pdl_interp.foreach %op : !pdl.operation in %ops {
+    %val = pdl_interp.get_result 0 of %op
+    pdl_interp.continue
+  } -> ^end
+
+  ^end:
+    pdl_interp.finalize
+}
+
+// -----
+
+func @users(%value: !pdl.value, %values: !pdl.range<value>) {
+  // all the users of a single value
+  %ops1 = pdl_interp.get_users of %value : !pdl.value
+
+  // all the users of the first value in a range
+  %ops2 = pdl_interp.get_users of %values : !pdl.range<value>
+
+  pdl_interp.finalize
+}