Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Rewrite/ByteCode.cpp
Show First 20 Lines • Show All 74 Lines • ▼ Show 20 Lines | enum OpCode : ByteCodeField { | ||||
/// Unconditional branch. | /// Unconditional branch. | ||||
Branch, | Branch, | ||||
/// Compare the operand count of an operation with a constant. | /// Compare the operand count of an operation with a constant. | ||||
CheckOperandCount, | CheckOperandCount, | ||||
/// Compare the name of an operation with a constant. | /// Compare the name of an operation with a constant. | ||||
CheckOperationName, | CheckOperationName, | ||||
/// Compare the result count of an operation with a constant. | /// Compare the result count of an operation with a constant. | ||||
CheckResultCount, | CheckResultCount, | ||||
/// Invoke a native creation method. | |||||
CreateNative, | |||||
/// Create an operation. | /// Create an operation. | ||||
CreateOperation, | CreateOperation, | ||||
/// Erase an operation. | /// Erase an operation. | ||||
EraseOp, | EraseOp, | ||||
/// Terminate a matcher or rewrite sequence. | /// Terminate a matcher or rewrite sequence. | ||||
Finalize, | Finalize, | ||||
/// Get a specific attribute of an operation. | /// Get a specific attribute of an operation. | ||||
GetAttribute, | GetAttribute, | ||||
▲ Show 20 Lines • Show All 50 Lines • ▼ Show 20 Lines | |||||
class Generator { | class Generator { | ||||
public: | public: | ||||
Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData, | Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData, | ||||
SmallVectorImpl<ByteCodeField> &matcherByteCode, | SmallVectorImpl<ByteCodeField> &matcherByteCode, | ||||
SmallVectorImpl<ByteCodeField> &rewriterByteCode, | SmallVectorImpl<ByteCodeField> &rewriterByteCode, | ||||
SmallVectorImpl<PDLByteCodePattern> &patterns, | SmallVectorImpl<PDLByteCodePattern> &patterns, | ||||
ByteCodeField &maxValueMemoryIndex, | ByteCodeField &maxValueMemoryIndex, | ||||
llvm::StringMap<PDLConstraintFunction> &constraintFns, | llvm::StringMap<PDLConstraintFunction> &constraintFns, | ||||
llvm::StringMap<PDLCreateFunction> &createFns, | |||||
llvm::StringMap<PDLRewriteFunction> &rewriteFns) | llvm::StringMap<PDLRewriteFunction> &rewriteFns) | ||||
: ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), | : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), | ||||
rewriterByteCode(rewriterByteCode), patterns(patterns), | rewriterByteCode(rewriterByteCode), patterns(patterns), | ||||
maxValueMemoryIndex(maxValueMemoryIndex) { | maxValueMemoryIndex(maxValueMemoryIndex) { | ||||
for (auto it : llvm::enumerate(constraintFns)) | for (auto it : llvm::enumerate(constraintFns)) | ||||
constraintToMemIndex.try_emplace(it.value().first(), it.index()); | constraintToMemIndex.try_emplace(it.value().first(), it.index()); | ||||
for (auto it : llvm::enumerate(createFns)) | |||||
nativeCreateToMemIndex.try_emplace(it.value().first(), it.index()); | |||||
for (auto it : llvm::enumerate(rewriteFns)) | for (auto it : llvm::enumerate(rewriteFns)) | ||||
externalRewriterToMemIndex.try_emplace(it.value().first(), it.index()); | externalRewriterToMemIndex.try_emplace(it.value().first(), it.index()); | ||||
} | } | ||||
/// Generate the bytecode for the given PDL interpreter module. | /// Generate the bytecode for the given PDL interpreter module. | ||||
void generate(ModuleOp module); | void generate(ModuleOp module); | ||||
/// Return the memory index to use for the given value. | /// Return the memory index to use for the given value. | ||||
Show All 30 Lines | private: | ||||
void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer); | void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer); | void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer); | void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer); | void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer); | void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); | void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); | void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); | void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::CreateNativeOp op, ByteCodeWriter &writer); | |||||
void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); | void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); | void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); | void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer); | void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer); | void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer); | void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer); | void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer); | void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer); | ||||
Show All 15 Lines | private: | ||||
/// Mapping from the name of an externally registered rewrite to its index in | /// Mapping from the name of an externally registered rewrite to its index in | ||||
/// the bytecode registry. | /// the bytecode registry. | ||||
llvm::StringMap<ByteCodeField> externalRewriterToMemIndex; | llvm::StringMap<ByteCodeField> externalRewriterToMemIndex; | ||||
/// Mapping from the name of an externally registered constraint to its index | /// Mapping from the name of an externally registered constraint to its index | ||||
/// in the bytecode registry. | /// in the bytecode registry. | ||||
llvm::StringMap<ByteCodeField> constraintToMemIndex; | llvm::StringMap<ByteCodeField> constraintToMemIndex; | ||||
/// Mapping from the name of an externally registered creation method to its | |||||
/// index in the bytecode registry. | |||||
llvm::StringMap<ByteCodeField> nativeCreateToMemIndex; | |||||
/// Mapping from rewriter function name to the bytecode address of the | /// Mapping from rewriter function name to the bytecode address of the | ||||
/// rewriter function in byte. | /// rewriter function in byte. | ||||
llvm::StringMap<ByteCodeAddr> rewriterToAddr; | llvm::StringMap<ByteCodeAddr> rewriterToAddr; | ||||
/// Mapping from a uniqued storage object to its memory index within | /// Mapping from a uniqued storage object to its memory index within | ||||
/// `uniquedData`. | /// `uniquedData`. | ||||
DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex; | DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex; | ||||
▲ Show 20 Lines • Show All 237 Lines • ▼ Show 20 Lines | |||||
void Generator::generate(Operation *op, ByteCodeWriter &writer) { | void Generator::generate(Operation *op, ByteCodeWriter &writer) { | ||||
TypeSwitch<Operation *>(op) | TypeSwitch<Operation *>(op) | ||||
.Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp, | .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp, | ||||
pdl_interp::AreEqualOp, pdl_interp::BranchOp, | pdl_interp::AreEqualOp, pdl_interp::BranchOp, | ||||
pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp, | pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp, | ||||
pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp, | pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp, | ||||
pdl_interp::CheckTypeOp, pdl_interp::CreateAttributeOp, | pdl_interp::CheckTypeOp, pdl_interp::CreateAttributeOp, | ||||
pdl_interp::CreateNativeOp, pdl_interp::CreateOperationOp, | pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp, | ||||
pdl_interp::CreateTypeOp, pdl_interp::EraseOp, | pdl_interp::EraseOp, pdl_interp::FinalizeOp, | ||||
pdl_interp::FinalizeOp, pdl_interp::GetAttributeOp, | pdl_interp::GetAttributeOp, pdl_interp::GetAttributeTypeOp, | ||||
pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp, | pdl_interp::GetDefiningOpOp, pdl_interp::GetOperandOp, | ||||
pdl_interp::GetOperandOp, pdl_interp::GetResultOp, | pdl_interp::GetResultOp, pdl_interp::GetValueTypeOp, | ||||
pdl_interp::GetValueTypeOp, pdl_interp::InferredTypeOp, | pdl_interp::InferredTypeOp, pdl_interp::IsNotNullOp, | ||||
pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp, | pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp, | ||||
pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp, | pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp, | ||||
pdl_interp::SwitchTypeOp, pdl_interp::SwitchOperandCountOp, | pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp, | ||||
pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>( | pdl_interp::SwitchResultCountOp>( | ||||
[&](auto interpOp) { this->generate(interpOp, writer); }) | [&](auto interpOp) { this->generate(interpOp, writer); }) | ||||
.Default([](Operation *) { | .Default([](Operation *) { | ||||
llvm_unreachable("unknown `pdl_interp` operation"); | llvm_unreachable("unknown `pdl_interp` operation"); | ||||
}); | }); | ||||
} | } | ||||
void Generator::generate(pdl_interp::ApplyConstraintOp op, | void Generator::generate(pdl_interp::ApplyConstraintOp op, | ||||
ByteCodeWriter &writer) { | ByteCodeWriter &writer) { | ||||
assert(constraintToMemIndex.count(op.name()) && | assert(constraintToMemIndex.count(op.name()) && | ||||
"expected index for constraint function"); | "expected index for constraint function"); | ||||
writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()], | writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()], | ||||
op.constParamsAttr()); | op.constParamsAttr()); | ||||
writer.appendPDLValueList(op.args()); | writer.appendPDLValueList(op.args()); | ||||
writer.append(op.getSuccessors()); | writer.append(op.getSuccessors()); | ||||
} | } | ||||
void Generator::generate(pdl_interp::ApplyRewriteOp op, | void Generator::generate(pdl_interp::ApplyRewriteOp op, | ||||
ByteCodeWriter &writer) { | ByteCodeWriter &writer) { | ||||
assert(externalRewriterToMemIndex.count(op.name()) && | assert(externalRewriterToMemIndex.count(op.name()) && | ||||
"expected index for rewrite function"); | "expected index for rewrite function"); | ||||
writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()], | writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()], | ||||
op.constParamsAttr(), op.root()); | op.constParamsAttr()); | ||||
writer.appendPDLValueList(op.args()); | writer.appendPDLValueList(op.args()); | ||||
#ifndef NDEBUG | |||||
// In debug mode we also append the number of results so that we can assert | |||||
// that the native creation function gave us the correct number of results. | |||||
writer.append(ByteCodeField(op.results().size())); | |||||
#endif | |||||
for (Value result : op.results()) | |||||
writer.append(result); | |||||
} | } | ||||
void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) { | void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) { | ||||
writer.append(OpCode::AreEqual, op.lhs(), op.rhs(), op.getSuccessors()); | writer.append(OpCode::AreEqual, op.lhs(), op.rhs(), op.getSuccessors()); | ||||
} | } | ||||
void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) { | void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) { | ||||
writer.append(OpCode::Branch, SuccessorRange(op.getOperation())); | writer.append(OpCode::Branch, SuccessorRange(op.getOperation())); | ||||
} | } | ||||
void Generator::generate(pdl_interp::CheckAttributeOp op, | void Generator::generate(pdl_interp::CheckAttributeOp op, | ||||
Show All 19 Lines | |||||
void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) { | void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) { | ||||
writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors()); | writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors()); | ||||
} | } | ||||
void Generator::generate(pdl_interp::CreateAttributeOp op, | void Generator::generate(pdl_interp::CreateAttributeOp op, | ||||
ByteCodeWriter &writer) { | ByteCodeWriter &writer) { | ||||
// Simply repoint the memory index of the result to the constant. | // Simply repoint the memory index of the result to the constant. | ||||
getMemIndex(op.attribute()) = getMemIndex(op.value()); | getMemIndex(op.attribute()) = getMemIndex(op.value()); | ||||
} | } | ||||
void Generator::generate(pdl_interp::CreateNativeOp op, | |||||
ByteCodeWriter &writer) { | |||||
assert(nativeCreateToMemIndex.count(op.name()) && | |||||
"expected index for creation function"); | |||||
writer.append(OpCode::CreateNative, nativeCreateToMemIndex[op.name()], | |||||
op.result(), op.constParamsAttr()); | |||||
writer.appendPDLValueList(op.args()); | |||||
} | |||||
void Generator::generate(pdl_interp::CreateOperationOp op, | void Generator::generate(pdl_interp::CreateOperationOp op, | ||||
ByteCodeWriter &writer) { | ByteCodeWriter &writer) { | ||||
writer.append(OpCode::CreateOperation, op.operation(), | writer.append(OpCode::CreateOperation, op.operation(), | ||||
OperationName(op.name(), ctx), op.operands()); | OperationName(op.name(), ctx), op.operands()); | ||||
// Add the attributes. | // Add the attributes. | ||||
OperandRange attributes = op.attributes(); | OperandRange attributes = op.attributes(); | ||||
writer.append(static_cast<ByteCodeField>(attributes.size())); | writer.append(static_cast<ByteCodeField>(attributes.size())); | ||||
▲ Show 20 Lines • Show All 95 Lines • ▼ Show 20 Lines | |||||
} | } | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
// PDLByteCode | // PDLByteCode | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
PDLByteCode::PDLByteCode(ModuleOp module, | PDLByteCode::PDLByteCode(ModuleOp module, | ||||
llvm::StringMap<PDLConstraintFunction> constraintFns, | llvm::StringMap<PDLConstraintFunction> constraintFns, | ||||
llvm::StringMap<PDLCreateFunction> createFns, | |||||
llvm::StringMap<PDLRewriteFunction> rewriteFns) { | llvm::StringMap<PDLRewriteFunction> rewriteFns) { | ||||
Generator generator(module.getContext(), uniquedData, matcherByteCode, | Generator generator(module.getContext(), uniquedData, matcherByteCode, | ||||
rewriterByteCode, patterns, maxValueMemoryIndex, | rewriterByteCode, patterns, maxValueMemoryIndex, | ||||
constraintFns, createFns, rewriteFns); | constraintFns, rewriteFns); | ||||
generator.generate(module); | generator.generate(module); | ||||
// Initialize the external functions. | // Initialize the external functions. | ||||
for (auto &it : constraintFns) | for (auto &it : constraintFns) | ||||
constraintFunctions.push_back(std::move(it.second)); | constraintFunctions.push_back(std::move(it.second)); | ||||
for (auto &it : createFns) | |||||
createFunctions.push_back(std::move(it.second)); | |||||
for (auto &it : rewriteFns) | for (auto &it : rewriteFns) | ||||
rewriteFunctions.push_back(std::move(it.second)); | rewriteFunctions.push_back(std::move(it.second)); | ||||
} | } | ||||
/// Initialize the given state such that it can be used to execute the current | /// Initialize the given state such that it can be used to execute the current | ||||
/// bytecode. | /// bytecode. | ||||
void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const { | void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const { | ||||
state.memory.resize(maxValueMemoryIndex, nullptr); | state.memory.resize(maxValueMemoryIndex, nullptr); | ||||
Show All 11 Lines | |||||
public: | public: | ||||
ByteCodeExecutor(const ByteCodeField *curCodeIt, | ByteCodeExecutor(const ByteCodeField *curCodeIt, | ||||
MutableArrayRef<const void *> memory, | MutableArrayRef<const void *> memory, | ||||
ArrayRef<const void *> uniquedMemory, | ArrayRef<const void *> uniquedMemory, | ||||
ArrayRef<ByteCodeField> code, | ArrayRef<ByteCodeField> code, | ||||
ArrayRef<PatternBenefit> currentPatternBenefits, | ArrayRef<PatternBenefit> currentPatternBenefits, | ||||
ArrayRef<PDLByteCodePattern> patterns, | ArrayRef<PDLByteCodePattern> patterns, | ||||
ArrayRef<PDLConstraintFunction> constraintFunctions, | ArrayRef<PDLConstraintFunction> constraintFunctions, | ||||
ArrayRef<PDLCreateFunction> createFunctions, | |||||
ArrayRef<PDLRewriteFunction> rewriteFunctions) | ArrayRef<PDLRewriteFunction> rewriteFunctions) | ||||
: curCodeIt(curCodeIt), memory(memory), uniquedMemory(uniquedMemory), | : curCodeIt(curCodeIt), memory(memory), uniquedMemory(uniquedMemory), | ||||
code(code), currentPatternBenefits(currentPatternBenefits), | code(code), currentPatternBenefits(currentPatternBenefits), | ||||
patterns(patterns), constraintFunctions(constraintFunctions), | patterns(patterns), constraintFunctions(constraintFunctions), | ||||
createFunctions(createFunctions), rewriteFunctions(rewriteFunctions) {} | rewriteFunctions(rewriteFunctions) {} | ||||
/// Start executing the code at the current bytecode index. `matches` is an | /// Start executing the code at the current bytecode index. `matches` is an | ||||
/// optional field provided when this function is executed in a matching | /// optional field provided when this function is executed in a matching | ||||
/// context. | /// context. | ||||
void execute(PatternRewriter &rewriter, | void execute(PatternRewriter &rewriter, | ||||
SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr, | SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr, | ||||
Optional<Location> mainRewriteLoc = {}); | Optional<Location> mainRewriteLoc = {}); | ||||
private: | private: | ||||
/// Internal implementation of executing each of the bytecode commands. | /// Internal implementation of executing each of the bytecode commands. | ||||
void executeApplyConstraint(PatternRewriter &rewriter); | void executeApplyConstraint(PatternRewriter &rewriter); | ||||
void executeApplyRewrite(PatternRewriter &rewriter); | void executeApplyRewrite(PatternRewriter &rewriter); | ||||
void executeAreEqual(); | void executeAreEqual(); | ||||
void executeBranch(); | void executeBranch(); | ||||
void executeCheckOperandCount(); | void executeCheckOperandCount(); | ||||
void executeCheckOperationName(); | void executeCheckOperationName(); | ||||
void executeCheckResultCount(); | void executeCheckResultCount(); | ||||
void executeCreateNative(PatternRewriter &rewriter); | |||||
void executeCreateOperation(PatternRewriter &rewriter, | void executeCreateOperation(PatternRewriter &rewriter, | ||||
Location mainRewriteLoc); | Location mainRewriteLoc); | ||||
void executeEraseOp(PatternRewriter &rewriter); | void executeEraseOp(PatternRewriter &rewriter); | ||||
void executeGetAttribute(); | void executeGetAttribute(); | ||||
void executeGetAttributeType(); | void executeGetAttributeType(); | ||||
void executeGetDefiningOp(); | void executeGetDefiningOp(); | ||||
void executeGetOperand(unsigned index); | void executeGetOperand(unsigned index); | ||||
void executeGetResult(unsigned index); | void executeGetResult(unsigned index); | ||||
▲ Show 20 Lines • Show All 109 Lines • ▼ Show 20 Lines | private: | ||||
MutableArrayRef<const void *> memory; | MutableArrayRef<const void *> memory; | ||||
/// References to ByteCode data necessary for execution. | /// References to ByteCode data necessary for execution. | ||||
ArrayRef<const void *> uniquedMemory; | ArrayRef<const void *> uniquedMemory; | ||||
ArrayRef<ByteCodeField> code; | ArrayRef<ByteCodeField> code; | ||||
ArrayRef<PatternBenefit> currentPatternBenefits; | ArrayRef<PatternBenefit> currentPatternBenefits; | ||||
ArrayRef<PDLByteCodePattern> patterns; | ArrayRef<PDLByteCodePattern> patterns; | ||||
ArrayRef<PDLConstraintFunction> constraintFunctions; | ArrayRef<PDLConstraintFunction> constraintFunctions; | ||||
ArrayRef<PDLCreateFunction> createFunctions; | |||||
ArrayRef<PDLRewriteFunction> rewriteFunctions; | ArrayRef<PDLRewriteFunction> rewriteFunctions; | ||||
}; | }; | ||||
/// This class is an instantiation of the PDLResultList that provides access to | |||||
/// the returned results. This API is not on `PDLResultList` to avoid | |||||
/// overexposing access to information specific solely to the ByteCode. | |||||
class ByteCodeRewriteResultList : public PDLResultList { | |||||
public: | |||||
/// Return the list of PDL results. | |||||
MutableArrayRef<PDLValue> getResults() { return results; } | |||||
}; | |||||
} // end anonymous namespace | } // end anonymous namespace | ||||
void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { | void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { | ||||
LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); | LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); | ||||
const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; | const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; | ||||
ArrayAttr constParams = read<ArrayAttr>(); | ArrayAttr constParams = read<ArrayAttr>(); | ||||
SmallVector<PDLValue, 16> args; | SmallVector<PDLValue, 16> args; | ||||
readList<PDLValue>(args); | readList<PDLValue>(args); | ||||
LLVM_DEBUG({ | LLVM_DEBUG({ | ||||
llvm::dbgs() << " * Arguments: "; | llvm::dbgs() << " * Arguments: "; | ||||
llvm::interleaveComma(args, llvm::dbgs()); | llvm::interleaveComma(args, llvm::dbgs()); | ||||
llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; | llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; | ||||
}); | }); | ||||
// Invoke the constraint and jump to the proper destination. | // Invoke the constraint and jump to the proper destination. | ||||
selectJump(succeeded(constraintFn(args, constParams, rewriter))); | selectJump(succeeded(constraintFn(args, constParams, rewriter))); | ||||
} | } | ||||
void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { | void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { | ||||
LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); | LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); | ||||
const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; | const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; | ||||
ArrayAttr constParams = read<ArrayAttr>(); | ArrayAttr constParams = read<ArrayAttr>(); | ||||
Operation *root = read<Operation *>(); | |||||
SmallVector<PDLValue, 16> args; | SmallVector<PDLValue, 16> args; | ||||
readList<PDLValue>(args); | readList<PDLValue>(args); | ||||
LLVM_DEBUG({ | LLVM_DEBUG({ | ||||
llvm::dbgs() << " * Root: " << *root << "\n * Arguments: "; | llvm::dbgs() << " * Arguments: "; | ||||
llvm::interleaveComma(args, llvm::dbgs()); | llvm::interleaveComma(args, llvm::dbgs()); | ||||
llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; | llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; | ||||
}); | }); | ||||
ByteCodeRewriteResultList results; | |||||
rewriteFn(args, constParams, rewriter, results); | |||||
// Store the results in the bytecode memory. | |||||
#ifndef NDEBUG | |||||
ByteCodeField expectedNumberOfResults = read(); | |||||
assert(results.getResults().size() == expectedNumberOfResults && | |||||
"native PDL rewrite function returned unexpected number of results"); | |||||
#endif | |||||
// Invoke the native rewrite function. | // Store the results in the bytecode memory. | ||||
rewriteFn(root, args, constParams, rewriter); | for (PDLValue &result : results.getResults()) { | ||||
LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); | |||||
memory[read()] = result.getAsOpaquePointer(); | |||||
} | |||||
} | } | ||||
void ByteCodeExecutor::executeAreEqual() { | void ByteCodeExecutor::executeAreEqual() { | ||||
LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); | LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); | ||||
const void *lhs = read<const void *>(); | const void *lhs = read<const void *>(); | ||||
const void *rhs = read<const void *>(); | const void *rhs = read<const void *>(); | ||||
LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n"); | LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n"); | ||||
Show All 30 Lines | void ByteCodeExecutor::executeCheckResultCount() { | ||||
Operation *op = read<Operation *>(); | Operation *op = read<Operation *>(); | ||||
uint32_t expectedCount = read<uint32_t>(); | uint32_t expectedCount = read<uint32_t>(); | ||||
LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n" | LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n" | ||||
<< " * Expected: " << expectedCount << "\n"); | << " * Expected: " << expectedCount << "\n"); | ||||
selectJump(op->getNumResults() == expectedCount); | selectJump(op->getNumResults() == expectedCount); | ||||
} | } | ||||
void ByteCodeExecutor::executeCreateNative(PatternRewriter &rewriter) { | |||||
LLVM_DEBUG(llvm::dbgs() << "Executing CreateNative:\n"); | |||||
const PDLCreateFunction &createFn = createFunctions[read()]; | |||||
ByteCodeField resultIndex = read(); | |||||
ArrayAttr constParams = read<ArrayAttr>(); | |||||
SmallVector<PDLValue, 16> args; | |||||
readList<PDLValue>(args); | |||||
LLVM_DEBUG({ | |||||
llvm::dbgs() << " * Arguments: "; | |||||
llvm::interleaveComma(args, llvm::dbgs()); | |||||
llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; | |||||
}); | |||||
PDLValue result = createFn(args, constParams, rewriter); | |||||
memory[resultIndex] = result.getAsOpaquePointer(); | |||||
LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); | |||||
} | |||||
void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, | void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, | ||||
Location mainRewriteLoc) { | Location mainRewriteLoc) { | ||||
LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n"); | LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n"); | ||||
unsigned memIndex = read(); | unsigned memIndex = read(); | ||||
OperationState state(mainRewriteLoc, read<OperationName>()); | OperationState state(mainRewriteLoc, read<OperationName>()); | ||||
readList<Value>(state.operands); | readList<Value>(state.operands); | ||||
for (unsigned i = 0, e = read(); i != e; ++i) { | for (unsigned i = 0, e = read(); i != e; ++i) { | ||||
▲ Show 20 Lines • Show All 260 Lines • ▼ Show 20 Lines | case CheckOperandCount: | ||||
executeCheckOperandCount(); | executeCheckOperandCount(); | ||||
break; | break; | ||||
case CheckOperationName: | case CheckOperationName: | ||||
executeCheckOperationName(); | executeCheckOperationName(); | ||||
break; | break; | ||||
case CheckResultCount: | case CheckResultCount: | ||||
executeCheckResultCount(); | executeCheckResultCount(); | ||||
break; | break; | ||||
case CreateNative: | |||||
executeCreateNative(rewriter); | |||||
break; | |||||
case CreateOperation: | case CreateOperation: | ||||
executeCreateOperation(rewriter, *mainRewriteLoc); | executeCreateOperation(rewriter, *mainRewriteLoc); | ||||
break; | break; | ||||
case EraseOp: | case EraseOp: | ||||
executeEraseOp(rewriter); | executeEraseOp(rewriter); | ||||
break; | break; | ||||
case Finalize: | case Finalize: | ||||
LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n"); | LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n"); | ||||
▲ Show 20 Lines • Show All 73 Lines • ▼ Show 20 Lines | void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, | ||||
SmallVectorImpl<MatchResult> &matches, | SmallVectorImpl<MatchResult> &matches, | ||||
PDLByteCodeMutableState &state) const { | PDLByteCodeMutableState &state) const { | ||||
// The first memory slot is always the root operation. | // The first memory slot is always the root operation. | ||||
state.memory[0] = op; | state.memory[0] = op; | ||||
// The matcher function always starts at code address 0. | // The matcher function always starts at code address 0. | ||||
ByteCodeExecutor executor(matcherByteCode.data(), state.memory, uniquedData, | ByteCodeExecutor executor(matcherByteCode.data(), state.memory, uniquedData, | ||||
matcherByteCode, state.currentPatternBenefits, | matcherByteCode, state.currentPatternBenefits, | ||||
patterns, constraintFunctions, createFunctions, | patterns, constraintFunctions, rewriteFunctions); | ||||
rewriteFunctions); | |||||
executor.execute(rewriter, &matches); | executor.execute(rewriter, &matches); | ||||
// Order the found matches by benefit. | // Order the found matches by benefit. | ||||
std::stable_sort(matches.begin(), matches.end(), | std::stable_sort(matches.begin(), matches.end(), | ||||
[](const MatchResult &lhs, const MatchResult &rhs) { | [](const MatchResult &lhs, const MatchResult &rhs) { | ||||
return lhs.benefit > rhs.benefit; | return lhs.benefit > rhs.benefit; | ||||
}); | }); | ||||
} | } | ||||
/// Run the rewriter of the given pattern on the root operation `op`. | /// Run the rewriter of the given pattern on the root operation `op`. | ||||
void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match, | void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match, | ||||
PDLByteCodeMutableState &state) const { | PDLByteCodeMutableState &state) const { | ||||
// The arguments of the rewrite function are stored at the start of the | // The arguments of the rewrite function are stored at the start of the | ||||
// memory buffer. | // memory buffer. | ||||
llvm::copy(match.values, state.memory.begin()); | llvm::copy(match.values, state.memory.begin()); | ||||
ByteCodeExecutor executor( | ByteCodeExecutor executor(&rewriterByteCode[match.pattern->getRewriterAddr()], | ||||
&rewriterByteCode[match.pattern->getRewriterAddr()], state.memory, | state.memory, uniquedData, rewriterByteCode, | ||||
uniquedData, rewriterByteCode, state.currentPatternBenefits, patterns, | state.currentPatternBenefits, patterns, | ||||
constraintFunctions, createFunctions, rewriteFunctions); | constraintFunctions, rewriteFunctions); | ||||
jpienaar: Ooc why does the executor not use ref for first arg? | |||||
For me it's more of conventional, i.e. it's a buffer of commands so it feels weird to pass a reference to the first value. I generally try to err on the side of "I pass this value in the way I expect to use it", which in this case is always as a buffer. rriddle: For me it's more of conventional, i.e. it's a buffer of commands so it feels weird to pass a… | |||||
executor.execute(rewriter, /*matches=*/nullptr, match.location); | executor.execute(rewriter, /*matches=*/nullptr, match.location); | ||||
} | } |
Ooc why does the executor not use ref for first arg?