Changeset View
Changeset View
Standalone View
Standalone View
mlir/test/lib/Rewrite/TestPDLByteCode.cpp
Show All 20 Lines | |||||
} | } | ||||
static LogicalResult customMultiEntityConstraint(ArrayRef<PDLValue> values, | static LogicalResult customMultiEntityConstraint(ArrayRef<PDLValue> values, | ||||
ArrayAttr constantParams, | ArrayAttr constantParams, | ||||
PatternRewriter &rewriter) { | PatternRewriter &rewriter) { | ||||
return customSingleEntityConstraint(values[1], constantParams, rewriter); | return customSingleEntityConstraint(values[1], constantParams, rewriter); | ||||
} | } | ||||
// Custom creator invoked from PDL. | // Custom creator invoked from PDL. | ||||
static PDLValue customCreate(ArrayRef<PDLValue> args, ArrayAttr constantParams, | static void customCreate(ArrayRef<PDLValue> args, ArrayAttr constantParams, | ||||
PatternRewriter &rewriter) { | PatternRewriter &rewriter, PDLResultList &results) { | ||||
return rewriter.createOperation( | results.push_back(rewriter.createOperation( | ||||
OperationState(args[0].cast<Operation *>()->getLoc(), "test.success")); | OperationState(args[0].cast<Operation *>()->getLoc(), "test.success"))); | ||||
} | } | ||||
/// Custom rewriter invoked from PDL. | /// Custom rewriter invoked from PDL. | ||||
static void customRewriter(Operation *root, ArrayRef<PDLValue> args, | static void customRewriter(ArrayRef<PDLValue> args, ArrayAttr constantParams, | ||||
ArrayAttr constantParams, | PatternRewriter &rewriter, PDLResultList &results) { | ||||
PatternRewriter &rewriter) { | Operation *root = args[0].cast<Operation *>(); | ||||
OperationState successOpState(root->getLoc(), "test.success"); | OperationState successOpState(root->getLoc(), "test.success"); | ||||
successOpState.addOperands(args[0].cast<Value>()); | successOpState.addOperands(args[1].cast<Value>()); | ||||
successOpState.addAttribute("constantParams", constantParams); | successOpState.addAttribute("constantParams", constantParams); | ||||
rewriter.createOperation(successOpState); | rewriter.createOperation(successOpState); | ||||
rewriter.eraseOp(root); | rewriter.eraseOp(root); | ||||
} | } | ||||
namespace { | namespace { | ||||
struct TestPDLByteCodePass | struct TestPDLByteCodePass | ||||
: public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> { | : public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> { | ||||
Show All 9 Lines | void runOnOperation() final { | ||||
// Process the pattern module. | // Process the pattern module. | ||||
patternModule.getOperation()->remove(); | patternModule.getOperation()->remove(); | ||||
PDLPatternModule pdlPattern(patternModule); | PDLPatternModule pdlPattern(patternModule); | ||||
pdlPattern.registerConstraintFunction("multi_entity_constraint", | pdlPattern.registerConstraintFunction("multi_entity_constraint", | ||||
customMultiEntityConstraint); | customMultiEntityConstraint); | ||||
pdlPattern.registerConstraintFunction("single_entity_constraint", | pdlPattern.registerConstraintFunction("single_entity_constraint", | ||||
customSingleEntityConstraint); | customSingleEntityConstraint); | ||||
pdlPattern.registerCreateFunction("creator", customCreate); | pdlPattern.registerRewriteFunction("creator", customCreate); | ||||
pdlPattern.registerRewriteFunction("rewriter", customRewriter); | pdlPattern.registerRewriteFunction("rewriter", customRewriter); | ||||
OwningRewritePatternList patternList(std::move(pdlPattern)); | OwningRewritePatternList patternList(std::move(pdlPattern)); | ||||
// Invoke the pattern driver with the provided patterns. | // Invoke the pattern driver with the provided patterns. | ||||
(void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(), | (void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(), | ||||
std::move(patternList)); | std::move(patternList)); | ||||
} | } | ||||
Show All 11 Lines |