diff --git a/mlir/include/mlir/IR/OpAsmInterface.td b/mlir/include/mlir/IR/OpAsmInterface.td --- a/mlir/include/mlir/IR/OpAsmInterface.td +++ b/mlir/include/mlir/IR/OpAsmInterface.td @@ -66,6 +66,33 @@ ), "", "return;" >, + InterfaceMethod<[{ + Get the outer values whose names a region's arguments should shadow. + + For regions under an `IsolatedFromAbove` operation, the names of + `Values` defined above the region may be re-used for region argument + names. For example, + ```c++ + setShadowingFn(region.getArgument(0), getOperand(0)); + ``` + would print as: + ```mlir + %foo = ... + my.isolated_from_above_op(%foo) ({ + ^bb0(%foo): // the outer %foo is shadowed + ... + }) + ``` + This should be combined with a custom printer/parser which elides the + shadowing region arguments. + }], + "void", "getAsmRegionArgumentShadowing", + (ins + "::mlir::Region&":$region, + "::mlir::OpAsmSetRegionArgumentShadowingFn":$setShadowingFn + ), + "", "return;" + >, InterfaceMethod<[{ Get the name to use for a given block inside a region attached to this operation. diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -427,12 +427,6 @@ bool printBlockTerminators = true, bool printEmptyBlock = false) = 0; - /// Renumber the arguments for the specified region to the same names as the - /// SSA values in namesToUse. This may only be used for IsolatedFromAbove - /// operations. If any entry in namesToUse is null, the corresponding - /// argument name is left alone. - virtual void shadowRegionArgs(Region &region, ValueRange namesToUse) = 0; - /// Prints an affine map of SSA ids, where SSA id names are used in place /// of dims/symbols.
/// Operand values must come from single-result sources, and be valid @@ -1566,9 +1560,14 @@ //===--------------------------------------------------------------------===// /// A functor used to set the name of the start of a result group of an -/// operation. See 'getAsmResultNames' below for more details. +/// operation. See `OpAsmOpInterface::getAsmResultNames` for more details. using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>; +/// A functor used to make a region argument shadow the name of a value. See +/// `OpAsmOpInterface::getAsmRegionArgumentShadowing` for more details. +using OpAsmSetRegionArgumentShadowingFn = + function_ref<void(BlockArgument, Value)>; + /// A functor used to set the name of blocks in regions directly nested under /// an operation. using OpAsmSetBlockNameFn = function_ref<void(Block *, StringRef)>; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -766,7 +766,6 @@ void printSymbolName(StringRef) override {} void printSuccessor(Block *) override {} void printSuccessorAndUseList(Block *, ValueRange) override {} - void shadowRegionArgs(Region &, ValueRange) override {} /// The printer flags to use when determining potential aliases. const OpPrintingFlags &printerFlags; @@ -1211,11 +1210,6 @@ /// Get the info for the given block. BlockInfo getBlockInfo(Block *block); - /// Renumber the arguments for the specified region to the same names as the - /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for - /// details. - void shadowRegionArgs(Region &region, ValueRange namesToUse); - private: /// Number the SSA values within the given IR unit. void numberValuesInRegion(Region &region); @@ -1386,35 +1380,6 @@ return it != blockNames.end() ?
it->second : invalidBlock; } -void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) { - assert(!region.empty() && "cannot shadow arguments of an empty region"); - assert(region.getNumArguments() == namesToUse.size() && - "incorrect number of names passed in"); - assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() && - "only KnownIsolatedFromAbove ops can shadow names"); - - SmallVector<char, 16> nameStr; - for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) { - auto nameToUse = namesToUse[i]; - if (nameToUse == nullptr) - continue; - auto nameToReplace = region.getArgument(i); - - nameStr.clear(); - llvm::raw_svector_ostream nameStream(nameStr); - printValueID(nameToUse, /*printResultNo=*/true, nameStream); - - // Entry block arguments should already have a pretty "arg" name. - assert(valueIDs[nameToReplace] == NameSentinel); - - // Use the name without the leading %. - auto name = StringRef(nameStream.str()).drop_front(); - - // Overwrite the name. - valueNames[nameToReplace] = name.copy(usedNameAllocator); - } -} - void SSANameState::numberValuesInRegion(Region &region) { auto setBlockArgNameFn = [&](Value arg, StringRef name) { assert(!valueIDs.count(arg) && "arg numbered multiple times"); @@ -1423,10 +1388,36 @@ setValueName(arg, name); }; + auto setRegionArgShadowingFn = [&](BlockArgument arg, Value valToShadow) { + assert(!valueIDs.count(arg) && "arg numbered multiple times"); + assert(arg.getOwner() == &region.front() && + "arg is not an argument of the current region"); + assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() && + "only known IsolatedFromAbove ops can shadow names"); + assert(valToShadow.getParentRegion()->isProperAncestor(&region) && + "shadowed value must be defined outside the region"); + + SmallVector<char, 16> nameStr; + llvm::raw_svector_ostream nameStream(nameStr); + printValueID(valToShadow, /*printResultNo=*/true, nameStream); + + // Use the name without the leading %.
+ auto name = StringRef(nameStream.str()).drop_front(); + + // NOTE: We bypass `setValueName` here because the name is deliberately not + // unique, and may contain invalid characters (e.g. if `valToShadow` + // is part of a result group, we may get a name like 'result#1'). + valueIDs[arg] = NameSentinel; + valueNames[arg] = name.copy(usedNameAllocator); + }; + if (!printerFlags.shouldPrintGenericOpForm()) { if (Operation *op = region.getParentOp()) { - if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op)) + if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op)) { asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn); + asmInterface.getAsmRegionArgumentShadowing(region, + setRegionArgShadowingFn); + } } } @@ -2999,14 +2990,6 @@ void printRegion(Region &region, bool printEntryBlockArgs, bool printBlockTerminators, bool printEmptyBlock) override; - /// Renumber the arguments for the specified region to the same names as the - /// SSA values in namesToUse. This may only be used for IsolatedFromAbove - /// operations. If any entry in namesToUse is null, the corresponding - /// argument name is left alone. - void shadowRegionArgs(Region &region, ValueRange namesToUse) override { - state.getSSANameState().shadowRegionArgs(region, namesToUse); - } - /// Print the given affine map with the symbol and dimension operands printed /// inline with the map. void printAffineMapOfSSAIds(AffineMapAttr mapAttr, @@ -3665,10 +3648,6 @@ } void Value::printAsOperand(raw_ostream &os, AsmState &state) { - // TODO: This doesn't necessarily capture all potential cases. - // Currently, region arguments can be shadowed when printing the main - // operation. If the IR hasn't been printed, this will produce the old SSA - // name and not the shadowed name.
state.getImpl().getSSANameState().printValueID(*this, /*printResultNo=*/true, os); } diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -984,11 +984,15 @@ void IsolatedRegionOp::print(OpAsmPrinter &p) { p << ' '; p.printOperand(getOperand()); - p.shadowRegionArgs(getRegion(), getOperand()); p << ' '; p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); } +void IsolatedRegionOp::getAsmRegionArgumentShadowing( + Region &region, OpAsmSetRegionArgumentShadowingFn setShadowingFn) { + setShadowingFn(region.getArgument(0), getOperand()); +} + //===----------------------------------------------------------------------===// // Test SSACFGRegionOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1929,7 +1929,9 @@ //===----------------------------------------------------------------------===// // Test region argument list parsing. -def IsolatedRegionOp : TEST_Op<"isolated_region", [IsolatedFromAbove]> { +def IsolatedRegionOp : TEST_Op<"isolated_region", [ + IsolatedFromAbove, + DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmRegionArgumentShadowing"]>]> { let summary = "isolated region operation"; let description = [{ Test op with an isolated region, to test passthrough region arguments. Each