diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp --- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp @@ -835,28 +835,7 @@ pdl::OperationOp op, function_ref mapRewriteValue, SmallVectorImpl &types, DenseMap &rewriteValues, bool &hasInferredResultTypes) { - // Look for an operation that was replaced by `op`. The result types will be - // inferred from the results that were replaced. Block *rewriterBlock = op->getBlock(); - for (OpOperand &use : op.op().getUses()) { - // Check that the use corresponds to a ReplaceOp and that it is the - // replacement value, not the operation being replaced. - pdl::ReplaceOp replOpUser = dyn_cast(use.getOwner()); - if (!replOpUser || use.getOperandNumber() == 0) - continue; - // Make sure the replaced operation was defined before this one. - Value replOpVal = replOpUser.operation(); - Operation *replacedOp = replOpVal.getDefiningOp(); - if (replacedOp->getBlock() == rewriterBlock && - !replacedOp->isBeforeInBlock(op)) - continue; - - Value replacedOpResults = builder.create( - replacedOp->getLoc(), mapRewriteValue(replOpVal)); - types.push_back(builder.create( - replacedOp->getLoc(), replacedOpResults)); - return; - } // Try to handle resolution for each of the result types individually. This is // preferred over type inferrence because it will allow for us to use existing @@ -895,6 +874,31 @@ return; } + // Look for an operation that was replaced by `op`. The result types will be + // inferred from the results that were replaced. + for (OpOperand &use : op.op().getUses()) { + // Check that the use corresponds to a ReplaceOp and that it is the + // replacement value, not the operation being replaced. + pdl::ReplaceOp replOpUser = dyn_cast(use.getOwner()); + if (!replOpUser || use.getOperandNumber() == 0) + continue; + // Make sure the replaced operation was defined before this one. PDL + // rewrites only have single block regions, so if the op isn't in the + // rewriter block (i.e. the current block of the operation) we already know + // it dominates (i.e. it's in the matcher). + Value replOpVal = replOpUser.operation(); + Operation *replacedOp = replOpVal.getDefiningOp(); + if (replacedOp->getBlock() == rewriterBlock && + !replacedOp->isBeforeInBlock(op)) + continue; + + Value replacedOpResults = builder.create( + replacedOp->getLoc(), mapRewriteValue(replOpVal)); + types.push_back(builder.create( + replacedOp->getLoc(), replacedOpResults)); + return; + } + // If the types could not be inferred from any context and there weren't any // explicit result types, assume the user actually meant for the operation to // have no results. diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir --- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir +++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir @@ -86,7 +86,7 @@ %root = operation "foo.op" -> (%rootType, %rootType1 : !pdl.type, !pdl.type) rewrite %root { %newType1 = type - %newOp = operation "foo.op" -> (%rootType, %newType1 : !pdl.type, !pdl.type) + %newOp = operation "foo.op" replace %root with %newOp } }