[mlir][PDL] Infer rewrite result types from matched operand types

Summary (leading text before the first diff header is ignored by patch tools):
  * PDL.cpp: a type defined in the matcher is now accepted as inferrable when
    one of its users outside the rewrite block constrains an operand, not only
    when it constrains the result of an input operation.
  * ops.mlir: reword two existing test comments and add two tests covering
    inference from a single operand type and from a range of operand types.

NOTE(review): this patch text was recovered from a whitespace-collapsed copy
in which every angle-bracketed argument (isa<...>, dyn_cast<...>,
!pdl.range<...>) had been stripped. Those arguments were reconstructed from
the surrounding op names; confirm them against the upstream commit before
applying.

diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp
--- a/mlir/lib/Dialect/PDL/IR/PDL.cpp
+++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp
@@ -207,16 +207,17 @@
     if (isa<ApplyNativeRewriteOp>(resultTypeOp))
       continue;
 
-    // If the type operation was defined in the matcher and constrains the
-    // result of an input operation, it can be used.
-    auto constrainsInputOp = [rewriterBlock](Operation *user) {
-      return user->getBlock() != rewriterBlock && isa<OperationOp>(user);
+    // If the type operation was defined in the matcher and constrains an
+    // operand or the result of an input operation, it can be used.
+    auto constrainsInput = [rewriterBlock](Operation *user) {
+      return user->getBlock() != rewriterBlock &&
+             isa<OperandOp, OperandsOp, OperationOp>(user);
     };
     if (TypeOp typeOp = dyn_cast<TypeOp>(resultTypeOp)) {
-      if (typeOp.type() || llvm::any_of(typeOp->getUsers(), constrainsInputOp))
+      if (typeOp.type() || llvm::any_of(typeOp->getUsers(), constrainsInput))
         continue;
     } else if (TypesOp typeOp = dyn_cast<TypesOp>(resultTypeOp)) {
-      if (typeOp.types() || llvm::any_of(typeOp->getUsers(), constrainsInputOp))
+      if (typeOp.types() || llvm::any_of(typeOp->getUsers(), constrainsInput))
         continue;
     }
 
diff --git a/mlir/test/Dialect/PDL/ops.mlir b/mlir/test/Dialect/PDL/ops.mlir
--- a/mlir/test/Dialect/PDL/ops.mlir
+++ b/mlir/test/Dialect/PDL/ops.mlir
@@ -88,7 +88,7 @@
 // -----
 
 // Check that the result type of an operation within a rewrite can be inferred
-// from types used within the match block.
+// from the result types of an operation within the match block.
 pdl.pattern @infer_type_from_type_used_in_match : benefit(1) {
   %type1 = pdl.type : i32
   %type2 = pdl.type
@@ -101,7 +101,7 @@
 // -----
 
 // Check that the result type of an operation within a rewrite can be inferred
-// from types used within the match block.
+// from the result types of an operation within the match block.
 pdl.pattern @infer_type_from_type_used_in_match : benefit(1) {
   %types = pdl.types
   %root = pdl.operation -> (%types : !pdl.range<type>)
@@ -113,6 +113,34 @@
 
 // -----
 
+// Check that the result type of an operation within a rewrite can be inferred
+// from the type of an operand within the match block.
+pdl.pattern @infer_type_from_type_used_in_match : benefit(1) {
+  %type1 = pdl.type
+  %type2 = pdl.type
+  %operand1 = pdl.operand : %type1
+  %operand2 = pdl.operand : %type2
+  %root = pdl.operation (%operand1, %operand2 : !pdl.value, !pdl.value)
+  pdl.rewrite %root {
+    %newOp = pdl.operation "foo.op" -> (%type1, %type2 : !pdl.type, !pdl.type)
+  }
+}
+
+// -----
+
+// Check that the result type of an operation within a rewrite can be inferred
+// from the types of operands within the match block.
+pdl.pattern @infer_type_from_type_used_in_match : benefit(1) {
+  %types = pdl.types
+  %operands = pdl.operands : %types
+  %root = pdl.operation (%operands : !pdl.range<value>)
+  pdl.rewrite %root {
+    %newOp = pdl.operation "foo.op" -> (%types : !pdl.range<type>)
+  }
+}
+
+// -----
+
 pdl.pattern @apply_rewrite_with_no_results : benefit(1) {
   %root = pdl.operation
   pdl.rewrite %root {