diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -689,9 +689,12 @@ } /// Register the given operations as legal. + void addLegalOp(OperationName op) { + setOpAction(op, LegalizationAction::Legal); + } template void addLegalOp() { - setOpAction(LegalizationAction::Legal); + addLegalOp(OperationName(OpT::getOperationName(), &ctx)); } template void addLegalOp() { @@ -701,11 +704,15 @@ /// Register the given operation as dynamically legal and set the dynamic /// legalization callback to the one provided. + void addDynamicallyLegalOp(OperationName op, + const DynamicLegalityCallbackFn &callback) { + setOpAction(op, LegalizationAction::Dynamic); + setLegalityCallback(op, callback); + } template void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback) { - OperationName opName(OpT::getOperationName(), &ctx); - setOpAction(opName, LegalizationAction::Dynamic); - setLegalityCallback(opName, callback); + addDynamicallyLegalOp(OperationName(OpT::getOperationName(), &ctx), + callback); } template void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback) { @@ -722,9 +729,12 @@ /// Register the given operation as illegal, i.e. this operation is known to /// not be supported by this target. + void addIllegalOp(OperationName op) { + setOpAction(op, LegalizationAction::Illegal); + } template void addIllegalOp() { - setOpAction(LegalizationAction::Illegal); + addIllegalOp(OperationName(OpT::getOperationName(), &ctx)); } template void addIllegalOp() { @@ -737,6 +747,8 @@ /// addition to the operation itself, all of the operations nested within are /// also considered legal. An optional dynamic legality callback may be /// provided to mark subsets of legal instances as recursively legal. 
+ void markOpRecursivelyLegal(OperationName name, + const DynamicLegalityCallbackFn &callback = {}); template void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback = {}) { OperationName opName(OpT::getOperationName(), &ctx); @@ -840,11 +852,6 @@ /// Set the dynamic legality callback for the unknown ops. void setLegalityCallback(const DynamicLegalityCallbackFn &callback); - /// Set the recursive legality callback for the given operation and mark the - /// operation as recursively legal. - void markOpRecursivelyLegal(OperationName name, - const DynamicLegalityCallbackFn &callback); /// The set of information that configures the legalization of an operation. struct LegalizationInfo { /// The legality action this operation was given. diff --git a/mlir/test/Transforms/test-rewrite-dynamic-op.mlir b/mlir/test/Transforms/test-rewrite-dynamic-op.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Transforms/test-rewrite-dynamic-op.mlir @@ -0,0 +1,12 @@ +// RUN: mlir-opt %s -test-rewrite-dynamic-op | FileCheck %s + +// Test that `test.dynamic_one_operand_two_results` is replaced with +// `test.dynamic_generic`. 
+ +// CHECK-LABEL: func @rewrite_dynamic_op +func @rewrite_dynamic_op(%arg0: i32) { + // CHECK-NEXT: %{{.*}}:2 = "test.dynamic_generic"(%arg0) : (i32) -> (i32, i32) + %0:2 = "test.dynamic_one_operand_two_results"(%arg0) : (i32) -> (i32, i32) + // CHECK-NEXT: return + return +} diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -955,6 +955,60 @@ }; } // namespace +//===----------------------------------------------------------------------===// +// Test patterns that use operations defined at runtime +//===----------------------------------------------------------------------===// + +namespace { +/// This pattern matches 'test.dynamic_one_operand_two_results' operations and +/// replaces them with dynamic 'test.dynamic_generic' operations. +struct RewriteDynamicOp : public RewritePattern { + RewriteDynamicOp(MLIRContext *context) + : RewritePattern("test.dynamic_one_operand_two_results", /*benefit=*/1, + context) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + assert(op->getName().getStringRef() == + "test.dynamic_one_operand_two_results" && + "rewrite pattern should only match operations with the right name"); + + OperationState state(op->getLoc(), "test.dynamic_generic", + op->getOperands(), op->getResultTypes(), + op->getAttrs()); + auto *newOp = rewriter.create(state); + rewriter.replaceOp(op, newOp->getResults()); + return success(); + } +}; + +struct TestRewriteDynamicOpDriver + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver) + + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert(); + } + StringRef getArgument() const final { return "test-rewrite-dynamic-op"; } + StringRef getDescription() const final { + return "Test rewriting on dynamic operations"; + } + 
void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + + ConversionTarget target(getContext()); + target.addIllegalOp( + OperationName("test.dynamic_one_operand_two_results", &getContext())); + target.addLegalOp(OperationName("test.dynamic_generic", &getContext())); + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + signalPassFailure(); + } +}; +} // namespace //===----------------------------------------------------------------------===// // Test type conversions //===----------------------------------------------------------------------===// @@ -1418,6 +1472,8 @@ PassRegistration(); PassRegistration(); + PassRegistration(); + PassRegistration(); PassRegistration(); }