Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Rewrite/ByteCode.cpp
Show All 14 Lines | |||||
#include "mlir/Dialect/PDL/IR/PDLTypes.h" | #include "mlir/Dialect/PDL/IR/PDLTypes.h" | ||||
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" | #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" | ||||
#include "mlir/IR/BuiltinOps.h" | #include "mlir/IR/BuiltinOps.h" | ||||
#include "mlir/IR/RegionGraphTraits.h" | #include "mlir/IR/RegionGraphTraits.h" | ||||
#include "llvm/ADT/IntervalMap.h" | #include "llvm/ADT/IntervalMap.h" | ||||
#include "llvm/ADT/PostOrderIterator.h" | #include "llvm/ADT/PostOrderIterator.h" | ||||
#include "llvm/ADT/TypeSwitch.h" | #include "llvm/ADT/TypeSwitch.h" | ||||
#include "llvm/Support/Debug.h" | #include "llvm/Support/Debug.h" | ||||
#include "llvm/Support/Format.h" | |||||
#include "llvm/Support/FormatVariadic.h" | |||||
#include <numeric> | |||||
#define DEBUG_TYPE "pdl-bytecode" | #define DEBUG_TYPE "pdl-bytecode" | ||||
using namespace mlir; | using namespace mlir; | ||||
using namespace mlir::detail; | using namespace mlir::detail; | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
// PDLByteCodePattern | // PDLByteCodePattern | ||||
Show All 24 Lines | |||||
/// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds | /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds | ||||
/// to the position of the pattern within the range returned by | /// to the position of the pattern within the range returned by | ||||
/// `PDLByteCode::getPatterns`. | /// `PDLByteCode::getPatterns`. | ||||
void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex, | void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex, | ||||
PatternBenefit benefit) { | PatternBenefit benefit) { | ||||
currentPatternBenefits[patternIndex] = benefit; | currentPatternBenefits[patternIndex] = benefit; | ||||
} | } | ||||
/// Cleanup any allocated state after a full match/rewrite has been completed. | |||||
/// This method should be called irregardless of whether the match+rewrite was a | |||||
/// success or not. | |||||
void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() { | |||||
allocatedTypeRangeMemory.clear(); | |||||
allocatedValueRangeMemory.clear(); | |||||
} | |||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
// Bytecode OpCodes | // Bytecode OpCodes | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
namespace { | namespace { | ||||
enum OpCode : ByteCodeField { | enum OpCode : ByteCodeField { | ||||
/// Apply an externally registered constraint. | /// Apply an externally registered constraint. | ||||
ApplyConstraint, | ApplyConstraint, | ||||
/// Apply an externally registered rewrite. | /// Apply an externally registered rewrite. | ||||
ApplyRewrite, | ApplyRewrite, | ||||
/// Check if two generic values are equal. | /// Check if two generic values are equal. | ||||
AreEqual, | AreEqual, | ||||
/// Check if two ranges are equal. | |||||
AreRangesEqual, | |||||
/// Unconditional branch. | /// Unconditional branch. | ||||
Branch, | Branch, | ||||
/// Compare the operand count of an operation with a constant. | /// Compare the operand count of an operation with a constant. | ||||
CheckOperandCount, | CheckOperandCount, | ||||
/// Compare the name of an operation with a constant. | /// Compare the name of an operation with a constant. | ||||
CheckOperationName, | CheckOperationName, | ||||
/// Compare the result count of an operation with a constant. | /// Compare the result count of an operation with a constant. | ||||
CheckResultCount, | CheckResultCount, | ||||
/// Compare a range of types to a constant range of types. | |||||
CheckTypes, | |||||
/// Create an operation. | /// Create an operation. | ||||
CreateOperation, | CreateOperation, | ||||
/// Create a range of types. | |||||
CreateTypes, | |||||
/// Erase an operation. | /// Erase an operation. | ||||
EraseOp, | EraseOp, | ||||
/// Terminate a matcher or rewrite sequence. | /// Terminate a matcher or rewrite sequence. | ||||
Finalize, | Finalize, | ||||
/// Get a specific attribute of an operation. | /// Get a specific attribute of an operation. | ||||
GetAttribute, | GetAttribute, | ||||
/// Get the type of an attribute. | /// Get the type of an attribute. | ||||
GetAttributeType, | GetAttributeType, | ||||
/// Get the defining operation of a value. | /// Get the defining operation of a value. | ||||
GetDefiningOp, | GetDefiningOp, | ||||
/// Get a specific operand of an operation. | /// Get a specific operand of an operation. | ||||
GetOperand0, | GetOperand0, | ||||
GetOperand1, | GetOperand1, | ||||
GetOperand2, | GetOperand2, | ||||
GetOperand3, | GetOperand3, | ||||
GetOperandN, | GetOperandN, | ||||
/// Get a specific operand group of an operation. | |||||
GetOperands, | |||||
/// Get a specific result of an operation. | /// Get a specific result of an operation. | ||||
GetResult0, | GetResult0, | ||||
GetResult1, | GetResult1, | ||||
GetResult2, | GetResult2, | ||||
GetResult3, | GetResult3, | ||||
GetResultN, | GetResultN, | ||||
/// Get a specific result group of an operation. | |||||
GetResults, | |||||
/// Get the type of a value. | /// Get the type of a value. | ||||
GetValueType, | GetValueType, | ||||
/// Get the types of a value range. | |||||
GetValueRangeTypes, | |||||
/// Check if a generic value is not null. | /// Check if a generic value is not null. | ||||
IsNotNull, | IsNotNull, | ||||
/// Record a successful pattern match. | /// Record a successful pattern match. | ||||
RecordMatch, | RecordMatch, | ||||
/// Replace an operation. | /// Replace an operation. | ||||
ReplaceOp, | ReplaceOp, | ||||
/// Compare an attribute with a set of constants. | /// Compare an attribute with a set of constants. | ||||
SwitchAttribute, | SwitchAttribute, | ||||
/// Compare the operand count of an operation with a set of constants. | /// Compare the operand count of an operation with a set of constants. | ||||
SwitchOperandCount, | SwitchOperandCount, | ||||
/// Compare the name of an operation with a set of constants. | /// Compare the name of an operation with a set of constants. | ||||
SwitchOperationName, | SwitchOperationName, | ||||
/// Compare the result count of an operation with a set of constants. | /// Compare the result count of an operation with a set of constants. | ||||
SwitchResultCount, | SwitchResultCount, | ||||
/// Compare a type with a set of constants. | /// Compare a type with a set of constants. | ||||
SwitchType, | SwitchType, | ||||
/// Compare a range of types with a set of constants. | |||||
SwitchTypes, | |||||
}; | }; | ||||
enum class PDLValueKind { Attribute, Operation, Type, Value }; | |||||
} // end anonymous namespace | } // end anonymous namespace | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
// ByteCode Generation | // ByteCode Generation | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
// Generator | // Generator | ||||
namespace { | namespace { | ||||
struct ByteCodeWriter; | struct ByteCodeWriter; | ||||
/// This class represents the main generator for the pattern bytecode. | /// This class represents the main generator for the pattern bytecode. | ||||
class Generator { | class Generator { | ||||
public: | public: | ||||
Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData, | Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData, | ||||
SmallVectorImpl<ByteCodeField> &matcherByteCode, | SmallVectorImpl<ByteCodeField> &matcherByteCode, | ||||
SmallVectorImpl<ByteCodeField> &rewriterByteCode, | SmallVectorImpl<ByteCodeField> &rewriterByteCode, | ||||
SmallVectorImpl<PDLByteCodePattern> &patterns, | SmallVectorImpl<PDLByteCodePattern> &patterns, | ||||
ByteCodeField &maxValueMemoryIndex, | ByteCodeField &maxValueMemoryIndex, | ||||
ByteCodeField &maxTypeRangeMemoryIndex, | |||||
ByteCodeField &maxValueRangeMemoryIndex, | |||||
llvm::StringMap<PDLConstraintFunction> &constraintFns, | llvm::StringMap<PDLConstraintFunction> &constraintFns, | ||||
llvm::StringMap<PDLRewriteFunction> &rewriteFns) | llvm::StringMap<PDLRewriteFunction> &rewriteFns) | ||||
: ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), | : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), | ||||
rewriterByteCode(rewriterByteCode), patterns(patterns), | rewriterByteCode(rewriterByteCode), patterns(patterns), | ||||
maxValueMemoryIndex(maxValueMemoryIndex) { | maxValueMemoryIndex(maxValueMemoryIndex), | ||||
maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex), | |||||
maxValueRangeMemoryIndex(maxValueRangeMemoryIndex) { | |||||
for (auto it : llvm::enumerate(constraintFns)) | for (auto it : llvm::enumerate(constraintFns)) | ||||
constraintToMemIndex.try_emplace(it.value().first(), it.index()); | constraintToMemIndex.try_emplace(it.value().first(), it.index()); | ||||
for (auto it : llvm::enumerate(rewriteFns)) | for (auto it : llvm::enumerate(rewriteFns)) | ||||
externalRewriterToMemIndex.try_emplace(it.value().first(), it.index()); | externalRewriterToMemIndex.try_emplace(it.value().first(), it.index()); | ||||
} | } | ||||
/// Generate the bytecode for the given PDL interpreter module. | /// Generate the bytecode for the given PDL interpreter module. | ||||
void generate(ModuleOp module); | void generate(ModuleOp module); | ||||
/// Return the memory index to use for the given value. | /// Return the memory index to use for the given value. | ||||
ByteCodeField &getMemIndex(Value value) { | ByteCodeField &getMemIndex(Value value) { | ||||
assert(valueToMemIndex.count(value) && | assert(valueToMemIndex.count(value) && | ||||
"expected memory index to be assigned"); | "expected memory index to be assigned"); | ||||
return valueToMemIndex[value]; | return valueToMemIndex[value]; | ||||
} | } | ||||
/// Return the range memory index used to store the given range value. | |||||
ByteCodeField &getRangeStorageIndex(Value value) { | |||||
assert(valueToRangeIndex.count(value) && | |||||
"expected range index to be assigned"); | |||||
return valueToRangeIndex[value]; | |||||
} | |||||
/// Return an index to use when referring to the given data that is uniqued in | /// Return an index to use when referring to the given data that is uniqued in | ||||
/// the MLIR context. | /// the MLIR context. | ||||
template <typename T> | template <typename T> | ||||
std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &> | std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &> | ||||
getMemIndex(T val) { | getMemIndex(T val) { | ||||
const void *opaqueVal = val.getAsOpaquePointer(); | const void *opaqueVal = val.getAsOpaquePointer(); | ||||
// Get or insert a reference to this value. | // Get or insert a reference to this value. | ||||
Show All 15 Lines | private: | ||||
void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer); | void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer); | void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer); | void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer); | void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer); | void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer); | void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); | void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); | void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer); | |||||
void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); | void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); | void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); | void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer); | |||||
void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); | void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer); | void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer); | void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer); | void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer); | void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer); | void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer); | |||||
void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer); | void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer); | |||||
void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer); | void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer); | void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer); | void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer); | void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer); | void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer); | void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer); | void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer); | |||||
void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer); | void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer); | void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer); | ||||
void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer); | void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer); | ||||
/// Mapping from value to its corresponding memory index. | /// Mapping from value to its corresponding memory index. | ||||
DenseMap<Value, ByteCodeField> valueToMemIndex; | DenseMap<Value, ByteCodeField> valueToMemIndex; | ||||
/// Mapping from a range value to its corresponding range storage index. | |||||
DenseMap<Value, ByteCodeField> valueToRangeIndex; | |||||
/// Mapping from the name of an externally registered rewrite to its index in | /// Mapping from the name of an externally registered rewrite to its index in | ||||
/// the bytecode registry. | /// the bytecode registry. | ||||
llvm::StringMap<ByteCodeField> externalRewriterToMemIndex; | llvm::StringMap<ByteCodeField> externalRewriterToMemIndex; | ||||
/// Mapping from the name of an externally registered constraint to its index | /// Mapping from the name of an externally registered constraint to its index | ||||
/// in the bytecode registry. | /// in the bytecode registry. | ||||
llvm::StringMap<ByteCodeField> constraintToMemIndex; | llvm::StringMap<ByteCodeField> constraintToMemIndex; | ||||
Show All 9 Lines | private: | ||||
MLIRContext *ctx; | MLIRContext *ctx; | ||||
/// Data of the ByteCode class to be populated. | /// Data of the ByteCode class to be populated. | ||||
std::vector<const void *> &uniquedData; | std::vector<const void *> &uniquedData; | ||||
SmallVectorImpl<ByteCodeField> &matcherByteCode; | SmallVectorImpl<ByteCodeField> &matcherByteCode; | ||||
SmallVectorImpl<ByteCodeField> &rewriterByteCode; | SmallVectorImpl<ByteCodeField> &rewriterByteCode; | ||||
SmallVectorImpl<PDLByteCodePattern> &patterns; | SmallVectorImpl<PDLByteCodePattern> &patterns; | ||||
ByteCodeField &maxValueMemoryIndex; | ByteCodeField &maxValueMemoryIndex; | ||||
ByteCodeField &maxTypeRangeMemoryIndex; | |||||
ByteCodeField &maxValueRangeMemoryIndex; | |||||
}; | }; | ||||
/// This class provides utilities for writing a bytecode stream. | /// This class provides utilities for writing a bytecode stream. | ||||
struct ByteCodeWriter { | struct ByteCodeWriter { | ||||
ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator) | ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator) | ||||
: bytecode(bytecode), generator(generator) {} | : bytecode(bytecode), generator(generator) {} | ||||
/// Append a field to the bytecode. | /// Append a field to the bytecode. | ||||
Show All 19 Lines | for (Block *successor : successors) { | ||||
unresolvedSuccessorRefs[successor].push_back(bytecode.size()); | unresolvedSuccessorRefs[successor].push_back(bytecode.size()); | ||||
append(ByteCodeAddr(0)); | append(ByteCodeAddr(0)); | ||||
} | } | ||||
} | } | ||||
/// Append a range of values that will be read as generic PDLValues. | /// Append a range of values that will be read as generic PDLValues. | ||||
void appendPDLValueList(OperandRange values) { | void appendPDLValueList(OperandRange values) { | ||||
bytecode.push_back(values.size()); | bytecode.push_back(values.size()); | ||||
for (Value value : values) { | for (Value value : values) | ||||
appendPDLValue(value); | |||||
} | |||||
/// Append a value as a PDLValue. | |||||
void appendPDLValue(Value value) { | |||||
appendPDLValueKind(value); | |||||
append(value); | |||||
} | |||||
/// Append the PDLValue::Kind of the given value. | |||||
void appendPDLValueKind(Value value) { | |||||
// Append the type of the value in addition to the value itself. | // Append the type of the value in addition to the value itself. | ||||
PDLValueKind kind = | PDLValue::Kind kind = | ||||
TypeSwitch<Type, PDLValueKind>(value.getType()) | TypeSwitch<Type, PDLValue::Kind>(value.getType()) | ||||
.Case<pdl::AttributeType>( | .Case<pdl::AttributeType>( | ||||
[](Type) { return PDLValueKind::Attribute; }) | [](Type) { return PDLValue::Kind::Attribute; }) | ||||
.Case<pdl::OperationType>( | .Case<pdl::OperationType>( | ||||
[](Type) { return PDLValueKind::Operation; }) | [](Type) { return PDLValue::Kind::Operation; }) | ||||
.Case<pdl::TypeType>([](Type) { return PDLValueKind::Type; }) | .Case<pdl::RangeType>([](pdl::RangeType rangeTy) { | ||||
.Case<pdl::ValueType>([](Type) { return PDLValueKind::Value; }); | if (rangeTy.getElementType().isa<pdl::TypeType>()) | ||||
return PDLValue::Kind::TypeRange; | |||||
return PDLValue::Kind::ValueRange; | |||||
}) | |||||
.Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; }) | |||||
.Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; }); | |||||
bytecode.push_back(static_cast<ByteCodeField>(kind)); | bytecode.push_back(static_cast<ByteCodeField>(kind)); | ||||
append(value); | |||||
} | |||||
} | } | ||||
/// Check if the given class `T` has an iterator type. | /// Check if the given class `T` has an iterator type. | ||||
template <typename T, typename... Args> | template <typename T, typename... Args> | ||||
using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer()); | using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer()); | ||||
/// Append a value that will be stored in a memory slot and not inline within | /// Append a value that will be stored in a memory slot and not inline within | ||||
/// the bytecode. | /// the bytecode. | ||||
Show All 24 Lines | struct ByteCodeWriter { | ||||
DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs; | DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs; | ||||
/// The underlying bytecode buffer. | /// The underlying bytecode buffer. | ||||
SmallVectorImpl<ByteCodeField> &bytecode; | SmallVectorImpl<ByteCodeField> &bytecode; | ||||
/// The main generator producing PDL. | /// The main generator producing PDL. | ||||
Generator &generator; | Generator &generator; | ||||
}; | }; | ||||
/// This class represents a live range of PDL Interpreter values, containing | |||||
/// information about when values are live within a match/rewrite. | |||||
struct ByteCodeLiveRange { | |||||
jpienaar: I stream output of these I found useful for debugging before. Not needed as part of this. | |||||
SG, will add more debug logging in a followup. rriddle: SG, will add more debug logging in a followup. | |||||
using Set = llvm::IntervalMap<ByteCodeField, char, 16>; | |||||
using Allocator = Set::Allocator; | |||||
ByteCodeLiveRange(Allocator &alloc) : liveness(alloc) {} | |||||
/// Union this live range with the one provided. | |||||
void unionWith(const ByteCodeLiveRange &rhs) { | |||||
for (auto it = rhs.liveness.begin(), e = rhs.liveness.end(); it != e; ++it) | |||||
liveness.insert(it.start(), it.stop(), /*dummyValue*/ 0); | |||||
} | |||||
/// Returns true if this range overlaps with the one provided. | |||||
bool overlaps(const ByteCodeLiveRange &rhs) const { | |||||
return llvm::IntervalMapOverlaps<Set, Set>(liveness, rhs.liveness).valid(); | |||||
} | |||||
/// A map representing the ranges of the match/rewrite that a value is live in | |||||
/// the interpreter. | |||||
llvm::IntervalMap<ByteCodeField, char, 16> liveness; | |||||
/// The type range storage index for this range. | |||||
Optional<unsigned> typeRangeIndex; | |||||
/// The value range storage index for this range. | |||||
Optional<unsigned> valueRangeIndex; | |||||
}; | |||||
} // end anonymous namespace | } // end anonymous namespace | ||||
void Generator::generate(ModuleOp module) { | void Generator::generate(ModuleOp module) { | ||||
FuncOp matcherFunc = module.lookupSymbol<FuncOp>( | FuncOp matcherFunc = module.lookupSymbol<FuncOp>( | ||||
pdl_interp::PDLInterpDialect::getMatcherFunctionName()); | pdl_interp::PDLInterpDialect::getMatcherFunctionName()); | ||||
ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>( | ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>( | ||||
pdl_interp::PDLInterpDialect::getRewriterModuleName()); | pdl_interp::PDLInterpDialect::getRewriterModuleName()); | ||||
assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module"); | assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module"); | ||||
Show All 31 Lines | void Generator::generate(ModuleOp module) { | ||||
} | } | ||||
} | } | ||||
void Generator::allocateMemoryIndices(FuncOp matcherFunc, | void Generator::allocateMemoryIndices(FuncOp matcherFunc, | ||||
ModuleOp rewriterModule) { | ModuleOp rewriterModule) { | ||||
// Rewriters use simplistic allocation scheme that simply assigns an index to | // Rewriters use simplistic allocation scheme that simply assigns an index to | ||||
// each result. | // each result. | ||||
for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) { | for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) { | ||||
ByteCodeField index = 0; | ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0; | ||||
auto processRewriterValue = [&](Value val) { | |||||
valueToMemIndex.try_emplace(val, index++); | |||||
if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) { | |||||
Type elementTy = rangeType.getElementType(); | |||||
if (elementTy.isa<pdl::TypeType>()) | |||||
valueToRangeIndex.try_emplace(val, typeRangeIndex++); | |||||
else if (elementTy.isa<pdl::ValueType>()) | |||||
valueToRangeIndex.try_emplace(val, valueRangeIndex++); | |||||
} | |||||
}; | |||||
for (BlockArgument arg : rewriterFunc.getArguments()) | for (BlockArgument arg : rewriterFunc.getArguments()) | ||||
valueToMemIndex.try_emplace(arg, index++); | processRewriterValue(arg); | ||||
rewriterFunc.getBody().walk([&](Operation *op) { | rewriterFunc.getBody().walk([&](Operation *op) { | ||||
for (Value result : op->getResults()) | for (Value result : op->getResults()) | ||||
valueToMemIndex.try_emplace(result, index++); | processRewriterValue(result); | ||||
}); | }); | ||||
if (index > maxValueMemoryIndex) | if (index > maxValueMemoryIndex) | ||||
maxValueMemoryIndex = index; | maxValueMemoryIndex = index; | ||||
if (typeRangeIndex > maxTypeRangeMemoryIndex) | |||||
maxTypeRangeMemoryIndex = typeRangeIndex; | |||||
if (valueRangeIndex > maxValueRangeMemoryIndex) | |||||
maxValueRangeMemoryIndex = valueRangeIndex; | |||||
} | } | ||||
// The matcher function uses a more sophisticated numbering that tries to | // The matcher function uses a more sophisticated numbering that tries to | ||||
// minimize the number of memory indices assigned. This is done by determining | // minimize the number of memory indices assigned. This is done by determining | ||||
// a live range of the values within the matcher, then the allocation is just | // a live range of the values within the matcher, then the allocation is just | ||||
// finding the minimal number of overlapping live ranges. This is essentially | // finding the minimal number of overlapping live ranges. This is essentially | ||||
// a simplified form of register allocation where we don't necessarily have a | // a simplified form of register allocation where we don't necessarily have a | ||||
// limited number of registers, but we still want to minimize the number used. | // limited number of registers, but we still want to minimize the number used. | ||||
DenseMap<Operation *, ByteCodeField> opToIndex; | DenseMap<Operation *, ByteCodeField> opToIndex; | ||||
matcherFunc.getBody().walk([&](Operation *op) { | matcherFunc.getBody().walk([&](Operation *op) { | ||||
opToIndex.insert(std::make_pair(op, opToIndex.size())); | opToIndex.insert(std::make_pair(op, opToIndex.size())); | ||||
}); | }); | ||||
// Liveness info for each of the defs within the matcher. | // Liveness info for each of the defs within the matcher. | ||||
using LivenessSet = llvm::IntervalMap<ByteCodeField, char, 16>; | ByteCodeLiveRange::Allocator allocator; | ||||
LivenessSet::Allocator allocator; | DenseMap<Value, ByteCodeLiveRange> valueDefRanges; | ||||
DenseMap<Value, LivenessSet> valueDefRanges; | |||||
// Assign the root operation being matched to slot 0. | // Assign the root operation being matched to slot 0. | ||||
BlockArgument rootOpArg = matcherFunc.getArgument(0); | BlockArgument rootOpArg = matcherFunc.getArgument(0); | ||||
valueToMemIndex[rootOpArg] = 0; | valueToMemIndex[rootOpArg] = 0; | ||||
// Walk each of the blocks, computing the def interval that the value is used. | // Walk each of the blocks, computing the def interval that the value is used. | ||||
Liveness matcherLiveness(matcherFunc); | Liveness matcherLiveness(matcherFunc); | ||||
for (Block &block : matcherFunc.getBody()) { | for (Block &block : matcherFunc.getBody()) { | ||||
const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block); | const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block); | ||||
assert(info && "expected liveness info for block"); | assert(info && "expected liveness info for block"); | ||||
auto processValue = [&](Value value, Operation *firstUseOrDef) { | auto processValue = [&](Value value, Operation *firstUseOrDef) { | ||||
// We don't need to process the root op argument, this value is always | // We don't need to process the root op argument, this value is always | ||||
// assigned to the first memory slot. | // assigned to the first memory slot. | ||||
if (value == rootOpArg) | if (value == rootOpArg) | ||||
return; | return; | ||||
// Set indices for the range of this block that the value is used. | // Set indices for the range of this block that the value is used. | ||||
auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first; | auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first; | ||||
defRangeIt->second.insert( | defRangeIt->second.liveness.insert( | ||||
opToIndex[firstUseOrDef], | opToIndex[firstUseOrDef], | ||||
opToIndex[info->getEndOperation(value, firstUseOrDef)], | opToIndex[info->getEndOperation(value, firstUseOrDef)], | ||||
/*dummyValue*/ 0); | /*dummyValue*/ 0); | ||||
// Check to see if this value is a range type. | |||||
if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) { | |||||
Type eleType = rangeTy.getElementType(); | |||||
if (eleType.isa<pdl::TypeType>()) | |||||
defRangeIt->second.typeRangeIndex = 0; | |||||
else if (eleType.isa<pdl::ValueType>()) | |||||
defRangeIt->second.valueRangeIndex = 0; | |||||
} | |||||
}; | }; | ||||
// Process the live-ins of this block. | // Process the live-ins of this block. | ||||
for (Value liveIn : info->in()) | for (Value liveIn : info->in()) | ||||
processValue(liveIn, &block.front()); | processValue(liveIn, &block.front()); | ||||
// Process any new defs within this block. | // Process any new defs within this block. | ||||
for (Operation &op : block) | for (Operation &op : block) | ||||
for (Value result : op.getResults()) | for (Value result : op.getResults()) | ||||
processValue(result, &op); | processValue(result, &op); | ||||
} | } | ||||
// Greedily allocate memory slots using the computed def live ranges. | // Greedily allocate memory slots using the computed def live ranges. | ||||
std::vector<LivenessSet> allocatedIndices; | std::vector<ByteCodeLiveRange> allocatedIndices; | ||||
ByteCodeField numIndices = 1, numTypeRanges = 0, numValueRanges = 0; | |||||
for (auto &defIt : valueDefRanges) { | for (auto &defIt : valueDefRanges) { | ||||
ByteCodeField &memIndex = valueToMemIndex[defIt.first]; | ByteCodeField &memIndex = valueToMemIndex[defIt.first]; | ||||
LivenessSet &defSet = defIt.second; | ByteCodeLiveRange &defRange = defIt.second; | ||||
// Try to allocate to an existing index. | // Try to allocate to an existing index. | ||||
for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) { | for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) { | ||||
LivenessSet &existingIndex = existingIndexIt.value(); | ByteCodeLiveRange &existingRange = existingIndexIt.value(); | ||||
llvm::IntervalMapOverlaps<LivenessSet, LivenessSet> overlaps( | if (!defRange.overlaps(existingRange)) { | ||||
defIt.second, existingIndex); | existingRange.unionWith(defRange); | ||||
if (overlaps.valid()) | |||||
continue; | |||||
// Union the range of the def within the existing index. | |||||
for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it) | |||||
existingIndex.insert(it.start(), it.stop(), /*dummyValue*/ 0); | |||||
memIndex = existingIndexIt.index() + 1; | memIndex = existingIndexIt.index() + 1; | ||||
if (defRange.typeRangeIndex) { | |||||
if (!existingRange.typeRangeIndex) | |||||
existingRange.typeRangeIndex = numTypeRanges++; | |||||
valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex; | |||||
} else if (defRange.valueRangeIndex) { | |||||
if (!existingRange.valueRangeIndex) | |||||
existingRange.valueRangeIndex = numValueRanges++; | |||||
valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex; | |||||
} | |||||
break; | |||||
} | |||||
} | } | ||||
// If no existing index could be used, add a new one. | // If no existing index could be used, add a new one. | ||||
if (memIndex == 0) { | if (memIndex == 0) { | ||||
allocatedIndices.emplace_back(allocator); | allocatedIndices.emplace_back(allocator); | ||||
for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it) | ByteCodeLiveRange &newRange = allocatedIndices.back(); | ||||
allocatedIndices.back().insert(it.start(), it.stop(), /*dummyValue*/ 0); | newRange.unionWith(defRange); | ||||
// Allocate an index for type/value ranges. | |||||
if (defRange.typeRangeIndex) { | |||||
newRange.typeRangeIndex = numTypeRanges; | |||||
valueToRangeIndex[defIt.first] = numTypeRanges++; | |||||
} else if (defRange.valueRangeIndex) { | |||||
newRange.valueRangeIndex = numValueRanges; | |||||
valueToRangeIndex[defIt.first] = numValueRanges++; | |||||
} | |||||
memIndex = allocatedIndices.size(); | memIndex = allocatedIndices.size(); | ||||
++numIndices; | |||||
} | } | ||||
} | } | ||||
// Update the max number of indices. | // Update the max number of indices. | ||||
ByteCodeField numMatcherIndices = allocatedIndices.size() + 1; | if (numIndices > maxValueMemoryIndex) | ||||
if (numMatcherIndices > maxValueMemoryIndex) | maxValueMemoryIndex = numIndices; | ||||
maxValueMemoryIndex = numMatcherIndices; | if (numTypeRanges > maxTypeRangeMemoryIndex) | ||||
maxTypeRangeMemoryIndex = numTypeRanges; | |||||
if (numValueRanges > maxValueRangeMemoryIndex) | |||||
maxValueRangeMemoryIndex = numValueRanges; | |||||
} | } | ||||
void Generator::generate(Operation *op, ByteCodeWriter &writer) { | void Generator::generate(Operation *op, ByteCodeWriter &writer) { | ||||
TypeSwitch<Operation *>(op) | TypeSwitch<Operation *>(op) | ||||
.Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp, | .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp, | ||||
pdl_interp::AreEqualOp, pdl_interp::BranchOp, | pdl_interp::AreEqualOp, pdl_interp::BranchOp, | ||||
pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp, | pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp, | ||||
pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp, | pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp, | ||||
pdl_interp::CheckTypeOp, pdl_interp::CreateAttributeOp, | pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp, | ||||
pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp, | pdl_interp::CreateAttributeOp, pdl_interp::CreateOperationOp, | ||||
pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp, | |||||
pdl_interp::EraseOp, pdl_interp::FinalizeOp, | pdl_interp::EraseOp, pdl_interp::FinalizeOp, | ||||
pdl_interp::GetAttributeOp, pdl_interp::GetAttributeTypeOp, | pdl_interp::GetAttributeOp, pdl_interp::GetAttributeTypeOp, | ||||
pdl_interp::GetDefiningOpOp, pdl_interp::GetOperandOp, | pdl_interp::GetDefiningOpOp, pdl_interp::GetOperandOp, | ||||
pdl_interp::GetResultOp, pdl_interp::GetValueTypeOp, | pdl_interp::GetOperandsOp, pdl_interp::GetResultOp, | ||||
pdl_interp::GetResultsOp, pdl_interp::GetValueTypeOp, | |||||
pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp, | pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp, | ||||
pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp, | pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp, | ||||
pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp, | pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp, | ||||
pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp, | pdl_interp::SwitchTypesOp, pdl_interp::SwitchOperandCountOp, | ||||
pdl_interp::SwitchResultCountOp>( | pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>( | ||||
[&](auto interpOp) { this->generate(interpOp, writer); }) | [&](auto interpOp) { this->generate(interpOp, writer); }) | ||||
.Default([](Operation *) { | .Default([](Operation *) { | ||||
llvm_unreachable("unknown `pdl_interp` operation"); | llvm_unreachable("unknown `pdl_interp` operation"); | ||||
}); | }); | ||||
} | } | ||||
void Generator::generate(pdl_interp::ApplyConstraintOp op, | void Generator::generate(pdl_interp::ApplyConstraintOp op, | ||||
ByteCodeWriter &writer) { | ByteCodeWriter &writer) { | ||||
assert(constraintToMemIndex.count(op.name()) && | assert(constraintToMemIndex.count(op.name()) && | ||||
"expected index for constraint function"); | "expected index for constraint function"); | ||||
writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()], | writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()], | ||||
op.constParamsAttr()); | op.constParamsAttr()); | ||||
writer.appendPDLValueList(op.args()); | writer.appendPDLValueList(op.args()); | ||||
writer.append(op.getSuccessors()); | writer.append(op.getSuccessors()); | ||||
} | } | ||||
void Generator::generate(pdl_interp::ApplyRewriteOp op, | void Generator::generate(pdl_interp::ApplyRewriteOp op, | ||||
ByteCodeWriter &writer) { | ByteCodeWriter &writer) { | ||||
assert(externalRewriterToMemIndex.count(op.name()) && | assert(externalRewriterToMemIndex.count(op.name()) && | ||||
"expected index for rewrite function"); | "expected index for rewrite function"); | ||||
writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()], | writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()], | ||||
op.constParamsAttr()); | op.constParamsAttr()); | ||||
writer.appendPDLValueList(op.args()); | writer.appendPDLValueList(op.args()); | ||||
ResultRange results = op.results(); | |||||
writer.append(ByteCodeField(results.size())); | |||||
for (Value result : results) { | |||||
// In debug mode we also record the expected kind of the result, so that we | |||||
// can provide extra verification of the native rewrite function. | |||||
#ifndef NDEBUG | #ifndef NDEBUG | ||||
// In debug mode we also append the number of results so that we can assert | writer.appendPDLValueKind(result); | ||||
// that the native creation function gave us the correct number of results. | |||||
writer.append(ByteCodeField(op.results().size())); | |||||
#endif | #endif | ||||
for (Value result : op.results()) | |||||
// Range results also need to append the range storage index. | |||||
if (result.getType().isa<pdl::RangeType>()) | |||||
writer.append(getRangeStorageIndex(result)); | |||||
writer.append(result); | writer.append(result); | ||||
} | } | ||||
} | |||||
void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) { | void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) { | ||||
writer.append(OpCode::AreEqual, op.lhs(), op.rhs(), op.getSuccessors()); | Value lhs = op.lhs(); | ||||
if (lhs.getType().isa<pdl::RangeType>()) { | |||||
writer.append(OpCode::AreRangesEqual); | |||||
writer.appendPDLValueKind(lhs); | |||||
writer.append(op.lhs(), op.rhs(), op.getSuccessors()); | |||||
return; | |||||
} | |||||
writer.append(OpCode::AreEqual, lhs, op.rhs(), op.getSuccessors()); | |||||
} | } | ||||
void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) { | void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) { | ||||
writer.append(OpCode::Branch, SuccessorRange(op.getOperation())); | writer.append(OpCode::Branch, SuccessorRange(op.getOperation())); | ||||
} | } | ||||
void Generator::generate(pdl_interp::CheckAttributeOp op, | void Generator::generate(pdl_interp::CheckAttributeOp op, | ||||
ByteCodeWriter &writer) { | ByteCodeWriter &writer) { | ||||
writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(), | writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(), | ||||
op.getSuccessors()); | op.getSuccessors()); | ||||
} | } | ||||
void Generator::generate(pdl_interp::CheckOperandCountOp op, | void Generator::generate(pdl_interp::CheckOperandCountOp op, | ||||
ByteCodeWriter &writer) { | ByteCodeWriter &writer) { | ||||
writer.append(OpCode::CheckOperandCount, op.operation(), op.count(), | writer.append(OpCode::CheckOperandCount, op.operation(), op.count(), | ||||
static_cast<ByteCodeField>(op.compareAtLeast()), | |||||
op.getSuccessors()); | op.getSuccessors()); | ||||
} | } | ||||
void Generator::generate(pdl_interp::CheckOperationNameOp op, | void Generator::generate(pdl_interp::CheckOperationNameOp op, | ||||
ByteCodeWriter &writer) { | ByteCodeWriter &writer) { | ||||
writer.append(OpCode::CheckOperationName, op.operation(), | writer.append(OpCode::CheckOperationName, op.operation(), | ||||
OperationName(op.name(), ctx), op.getSuccessors()); | OperationName(op.name(), ctx), op.getSuccessors()); | ||||
} | } | ||||
void Generator::generate(pdl_interp::CheckResultCountOp op, | void Generator::generate(pdl_interp::CheckResultCountOp op, | ||||
ByteCodeWriter &writer) { | ByteCodeWriter &writer) { | ||||
writer.append(OpCode::CheckResultCount, op.operation(), op.count(), | writer.append(OpCode::CheckResultCount, op.operation(), op.count(), | ||||
static_cast<ByteCodeField>(op.compareAtLeast()), | |||||
op.getSuccessors()); | op.getSuccessors()); | ||||
} | } | ||||
void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) { | void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) { | ||||
writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors()); | writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors()); | ||||
} | } | ||||
void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) { | |||||
writer.append(OpCode::CheckTypes, op.value(), op.types(), op.getSuccessors()); | |||||
} | |||||
void Generator::generate(pdl_interp::CreateAttributeOp op, | void Generator::generate(pdl_interp::CreateAttributeOp op, | ||||
ByteCodeWriter &writer) { | ByteCodeWriter &writer) { | ||||
// Simply repoint the memory index of the result to the constant. | // Simply repoint the memory index of the result to the constant. | ||||
getMemIndex(op.attribute()) = getMemIndex(op.value()); | getMemIndex(op.attribute()) = getMemIndex(op.value()); | ||||
} | } | ||||
void Generator::generate(pdl_interp::CreateOperationOp op, | void Generator::generate(pdl_interp::CreateOperationOp op, | ||||
ByteCodeWriter &writer) { | ByteCodeWriter &writer) { | ||||
writer.append(OpCode::CreateOperation, op.operation(), | writer.append(OpCode::CreateOperation, op.operation(), | ||||
OperationName(op.name(), ctx), op.operands()); | OperationName(op.name(), ctx)); | ||||
writer.appendPDLValueList(op.operands()); | |||||
// Add the attributes. | // Add the attributes. | ||||
OperandRange attributes = op.attributes(); | OperandRange attributes = op.attributes(); | ||||
writer.append(static_cast<ByteCodeField>(attributes.size())); | writer.append(static_cast<ByteCodeField>(attributes.size())); | ||||
for (auto it : llvm::zip(op.attributeNames(), op.attributes())) { | for (auto it : llvm::zip(op.attributeNames(), op.attributes())) { | ||||
writer.append( | writer.append( | ||||
Identifier::get(std::get<0>(it).cast<StringAttr>().getValue(), ctx), | Identifier::get(std::get<0>(it).cast<StringAttr>().getValue(), ctx), | ||||
std::get<1>(it)); | std::get<1>(it)); | ||||
} | } | ||||
writer.append(op.types()); | writer.appendPDLValueList(op.types()); | ||||
} | } | ||||
void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) { | void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) { | ||||
// Simply repoint the memory index of the result to the constant. | // Simply repoint the memory index of the result to the constant. | ||||
getMemIndex(op.result()) = getMemIndex(op.value()); | getMemIndex(op.result()) = getMemIndex(op.value()); | ||||
} | } | ||||
void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) { | |||||
writer.append(OpCode::CreateTypes, op.result(), | |||||
getRangeStorageIndex(op.result()), op.value()); | |||||
} | |||||
void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) { | void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) { | ||||
writer.append(OpCode::EraseOp, op.operation()); | writer.append(OpCode::EraseOp, op.operation()); | ||||
} | } | ||||
void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) { | void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) { | ||||
writer.append(OpCode::Finalize); | writer.append(OpCode::Finalize); | ||||
} | } | ||||
void Generator::generate(pdl_interp::GetAttributeOp op, | void Generator::generate(pdl_interp::GetAttributeOp op, | ||||
ByteCodeWriter &writer) { | ByteCodeWriter &writer) { | ||||
writer.append(OpCode::GetAttribute, op.attribute(), op.operation(), | writer.append(OpCode::GetAttribute, op.attribute(), op.operation(), | ||||
Identifier::get(op.name(), ctx)); | Identifier::get(op.name(), ctx)); | ||||
} | } | ||||
void Generator::generate(pdl_interp::GetAttributeTypeOp op, | void Generator::generate(pdl_interp::GetAttributeTypeOp op, | ||||
ByteCodeWriter &writer) { | ByteCodeWriter &writer) { | ||||
writer.append(OpCode::GetAttributeType, op.result(), op.value()); | writer.append(OpCode::GetAttributeType, op.result(), op.value()); | ||||
} | } | ||||
void Generator::generate(pdl_interp::GetDefiningOpOp op, | void Generator::generate(pdl_interp::GetDefiningOpOp op, | ||||
ByteCodeWriter &writer) { | ByteCodeWriter &writer) { | ||||
writer.append(OpCode::GetDefiningOp, op.operation(), op.value()); | writer.append(OpCode::GetDefiningOp, op.operation()); | ||||
writer.appendPDLValue(op.value()); | |||||
} | } | ||||
void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) { | void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) { | ||||
uint32_t index = op.index(); | uint32_t index = op.index(); | ||||
if (index < 4) | if (index < 4) | ||||
writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index)); | writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index)); | ||||
else | else | ||||
writer.append(OpCode::GetOperandN, index); | writer.append(OpCode::GetOperandN, index); | ||||
writer.append(op.operation(), op.value()); | writer.append(op.operation(), op.value()); | ||||
} | } | ||||
void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) { | |||||
Value result = op.value(); | |||||
Optional<uint32_t> index = op.index(); | |||||
writer.append(OpCode::GetOperands, | |||||
index.getValueOr(std::numeric_limits<uint32_t>::max()), | |||||
op.operation()); | |||||
if (result.getType().isa<pdl::RangeType>()) | |||||
writer.append(getRangeStorageIndex(result)); | |||||
else | |||||
writer.append(std::numeric_limits<ByteCodeField>::max()); | |||||
writer.append(result); | |||||
} | |||||
void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) { | void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) { | ||||
uint32_t index = op.index(); | uint32_t index = op.index(); | ||||
if (index < 4) | if (index < 4) | ||||
writer.append(static_cast<OpCode>(OpCode::GetResult0 + index)); | writer.append(static_cast<OpCode>(OpCode::GetResult0 + index)); | ||||
else | else | ||||
writer.append(OpCode::GetResultN, index); | writer.append(OpCode::GetResultN, index); | ||||
writer.append(op.operation(), op.value()); | writer.append(op.operation(), op.value()); | ||||
} | } | ||||
void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) { | |||||
Value result = op.value(); | |||||
Optional<uint32_t> index = op.index(); | |||||
writer.append(OpCode::GetResults, | |||||
index.getValueOr(std::numeric_limits<uint32_t>::max()), | |||||
op.operation()); | |||||
if (result.getType().isa<pdl::RangeType>()) | |||||
writer.append(getRangeStorageIndex(result)); | |||||
else | |||||
writer.append(std::numeric_limits<ByteCodeField>::max()); | |||||
writer.append(result); | |||||
} | |||||
void Generator::generate(pdl_interp::GetValueTypeOp op, | void Generator::generate(pdl_interp::GetValueTypeOp op, | ||||
ByteCodeWriter &writer) { | ByteCodeWriter &writer) { | ||||
if (op.getType().isa<pdl::RangeType>()) { | |||||
Value result = op.result(); | |||||
writer.append(OpCode::GetValueRangeTypes, result, | |||||
getRangeStorageIndex(result), op.value()); | |||||
} else { | |||||
writer.append(OpCode::GetValueType, op.result(), op.value()); | writer.append(OpCode::GetValueType, op.result(), op.value()); | ||||
} | } | ||||
} | |||||
void Generator::generate(pdl_interp::InferredTypesOp op, | void Generator::generate(pdl_interp::InferredTypesOp op, | ||||
ByteCodeWriter &writer) { | ByteCodeWriter &writer) { | ||||
// InferType maps to a null type as a marker for inferring result types. | // InferType maps to a null type as a marker for inferring result types. | ||||
getMemIndex(op.type()) = getMemIndex(Type()); | getMemIndex(op.type()) = getMemIndex(Type()); | ||||
} | } | ||||
void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) { | void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) { | ||||
writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors()); | writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors()); | ||||
} | } | ||||
void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) { | void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) { | ||||
ByteCodeField patternIndex = patterns.size(); | ByteCodeField patternIndex = patterns.size(); | ||||
patterns.emplace_back(PDLByteCodePattern::create( | patterns.emplace_back(PDLByteCodePattern::create( | ||||
op, rewriterToAddr[op.rewriter().getLeafReference()])); | op, rewriterToAddr[op.rewriter().getLeafReference()])); | ||||
writer.append(OpCode::RecordMatch, patternIndex, | writer.append(OpCode::RecordMatch, patternIndex, | ||||
SuccessorRange(op.getOperation()), op.matchedOps(), | SuccessorRange(op.getOperation()), op.matchedOps()); | ||||
op.inputs()); | writer.appendPDLValueList(op.inputs()); | ||||
} | } | ||||
void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) { | void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) { | ||||
writer.append(OpCode::ReplaceOp, op.operation(), op.replValues()); | writer.append(OpCode::ReplaceOp, op.operation()); | ||||
writer.appendPDLValueList(op.replValues()); | |||||
} | } | ||||
void Generator::generate(pdl_interp::SwitchAttributeOp op, | void Generator::generate(pdl_interp::SwitchAttributeOp op, | ||||
ByteCodeWriter &writer) { | ByteCodeWriter &writer) { | ||||
writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(), | writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(), | ||||
op.getSuccessors()); | op.getSuccessors()); | ||||
} | } | ||||
void Generator::generate(pdl_interp::SwitchOperandCountOp op, | void Generator::generate(pdl_interp::SwitchOperandCountOp op, | ||||
ByteCodeWriter &writer) { | ByteCodeWriter &writer) { | ||||
Show All 12 Lines | void Generator::generate(pdl_interp::SwitchResultCountOp op, | ||||
ByteCodeWriter &writer) { | ByteCodeWriter &writer) { | ||||
writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(), | writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(), | ||||
op.getSuccessors()); | op.getSuccessors()); | ||||
} | } | ||||
void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) { | void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) { | ||||
writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(), | writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(), | ||||
op.getSuccessors()); | op.getSuccessors()); | ||||
} | } | ||||
void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) { | |||||
writer.append(OpCode::SwitchTypes, op.value(), op.caseValuesAttr(), | |||||
op.getSuccessors()); | |||||
} | |||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
// PDLByteCode | // PDLByteCode | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
PDLByteCode::PDLByteCode(ModuleOp module, | PDLByteCode::PDLByteCode(ModuleOp module, | ||||
llvm::StringMap<PDLConstraintFunction> constraintFns, | llvm::StringMap<PDLConstraintFunction> constraintFns, | ||||
llvm::StringMap<PDLRewriteFunction> rewriteFns) { | llvm::StringMap<PDLRewriteFunction> rewriteFns) { | ||||
Generator generator(module.getContext(), uniquedData, matcherByteCode, | Generator generator(module.getContext(), uniquedData, matcherByteCode, | ||||
rewriterByteCode, patterns, maxValueMemoryIndex, | rewriterByteCode, patterns, maxValueMemoryIndex, | ||||
constraintFns, rewriteFns); | maxTypeRangeCount, maxValueRangeCount, constraintFns, | ||||
rewriteFns); | |||||
generator.generate(module); | generator.generate(module); | ||||
// Initialize the external functions. | // Initialize the external functions. | ||||
for (auto &it : constraintFns) | for (auto &it : constraintFns) | ||||
constraintFunctions.push_back(std::move(it.second)); | constraintFunctions.push_back(std::move(it.second)); | ||||
for (auto &it : rewriteFns) | for (auto &it : rewriteFns) | ||||
rewriteFunctions.push_back(std::move(it.second)); | rewriteFunctions.push_back(std::move(it.second)); | ||||
} | } | ||||
/// Initialize the given state such that it can be used to execute the current | /// Initialize the given state such that it can be used to execute the current | ||||
/// bytecode. | /// bytecode. | ||||
void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const { | void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const { | ||||
state.memory.resize(maxValueMemoryIndex, nullptr); | state.memory.resize(maxValueMemoryIndex, nullptr); | ||||
state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange()); | |||||
state.valueRangeMemory.resize(maxValueRangeCount, ValueRange()); | |||||
state.currentPatternBenefits.reserve(patterns.size()); | state.currentPatternBenefits.reserve(patterns.size()); | ||||
for (const PDLByteCodePattern &pattern : patterns) | for (const PDLByteCodePattern &pattern : patterns) | ||||
state.currentPatternBenefits.push_back(pattern.getBenefit()); | state.currentPatternBenefits.push_back(pattern.getBenefit()); | ||||
} | } | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
// ByteCode Execution | // ByteCode Execution | ||||
namespace { | namespace { | ||||
/// This class provides support for executing a bytecode stream. | /// This class provides support for executing a bytecode stream. | ||||
class ByteCodeExecutor { | class ByteCodeExecutor { | ||||
public: | public: | ||||
ByteCodeExecutor(const ByteCodeField *curCodeIt, | ByteCodeExecutor( | ||||
MutableArrayRef<const void *> memory, | const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory, | ||||
ArrayRef<const void *> uniquedMemory, | MutableArrayRef<TypeRange> typeRangeMemory, | ||||
ArrayRef<ByteCodeField> code, | std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory, | ||||
MutableArrayRef<ValueRange> valueRangeMemory, | |||||
std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory, | |||||
ArrayRef<const void *> uniquedMemory, ArrayRef<ByteCodeField> code, | |||||
ArrayRef<PatternBenefit> currentPatternBenefits, | ArrayRef<PatternBenefit> currentPatternBenefits, | ||||
ArrayRef<PDLByteCodePattern> patterns, | ArrayRef<PDLByteCodePattern> patterns, | ||||
ArrayRef<PDLConstraintFunction> constraintFunctions, | ArrayRef<PDLConstraintFunction> constraintFunctions, | ||||
ArrayRef<PDLRewriteFunction> rewriteFunctions) | ArrayRef<PDLRewriteFunction> rewriteFunctions) | ||||
: curCodeIt(curCodeIt), memory(memory), uniquedMemory(uniquedMemory), | : curCodeIt(curCodeIt), memory(memory), typeRangeMemory(typeRangeMemory), | ||||
code(code), currentPatternBenefits(currentPatternBenefits), | allocatedTypeRangeMemory(allocatedTypeRangeMemory), | ||||
patterns(patterns), constraintFunctions(constraintFunctions), | valueRangeMemory(valueRangeMemory), | ||||
allocatedValueRangeMemory(allocatedValueRangeMemory), | |||||
uniquedMemory(uniquedMemory), code(code), | |||||
currentPatternBenefits(currentPatternBenefits), patterns(patterns), | |||||
constraintFunctions(constraintFunctions), | |||||
rewriteFunctions(rewriteFunctions) {} | rewriteFunctions(rewriteFunctions) {} | ||||
/// Start executing the code at the current bytecode index. `matches` is an | /// Start executing the code at the current bytecode index. `matches` is an | ||||
/// optional field provided when this function is executed in a matching | /// optional field provided when this function is executed in a matching | ||||
/// context. | /// context. | ||||
void execute(PatternRewriter &rewriter, | void execute(PatternRewriter &rewriter, | ||||
SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr, | SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr, | ||||
Optional<Location> mainRewriteLoc = {}); | Optional<Location> mainRewriteLoc = {}); | ||||
private: | private: | ||||
/// Internal implementation of executing each of the bytecode commands. | /// Internal implementation of executing each of the bytecode commands. | ||||
void executeApplyConstraint(PatternRewriter &rewriter); | void executeApplyConstraint(PatternRewriter &rewriter); | ||||
void executeApplyRewrite(PatternRewriter &rewriter); | void executeApplyRewrite(PatternRewriter &rewriter); | ||||
void executeAreEqual(); | void executeAreEqual(); | ||||
void executeAreRangesEqual(); | |||||
void executeBranch(); | void executeBranch(); | ||||
void executeCheckOperandCount(); | void executeCheckOperandCount(); | ||||
void executeCheckOperationName(); | void executeCheckOperationName(); | ||||
void executeCheckResultCount(); | void executeCheckResultCount(); | ||||
void executeCheckTypes(); | |||||
void executeCreateOperation(PatternRewriter &rewriter, | void executeCreateOperation(PatternRewriter &rewriter, | ||||
Location mainRewriteLoc); | Location mainRewriteLoc); | ||||
void executeCreateTypes(); | |||||
void executeEraseOp(PatternRewriter &rewriter); | void executeEraseOp(PatternRewriter &rewriter); | ||||
void executeGetAttribute(); | void executeGetAttribute(); | ||||
void executeGetAttributeType(); | void executeGetAttributeType(); | ||||
void executeGetDefiningOp(); | void executeGetDefiningOp(); | ||||
void executeGetOperand(unsigned index); | void executeGetOperand(unsigned index); | ||||
void executeGetOperands(); | |||||
void executeGetResult(unsigned index); | void executeGetResult(unsigned index); | ||||
void executeGetResults(); | |||||
void executeGetValueType(); | void executeGetValueType(); | ||||
void executeGetValueRangeTypes(); | |||||
void executeIsNotNull(); | void executeIsNotNull(); | ||||
void executeRecordMatch(PatternRewriter &rewriter, | void executeRecordMatch(PatternRewriter &rewriter, | ||||
SmallVectorImpl<PDLByteCode::MatchResult> &matches); | SmallVectorImpl<PDLByteCode::MatchResult> &matches); | ||||
void executeReplaceOp(PatternRewriter &rewriter); | void executeReplaceOp(PatternRewriter &rewriter); | ||||
void executeSwitchAttribute(); | void executeSwitchAttribute(); | ||||
void executeSwitchOperandCount(); | void executeSwitchOperandCount(); | ||||
void executeSwitchOperationName(); | void executeSwitchOperationName(); | ||||
void executeSwitchResultCount(); | void executeSwitchResultCount(); | ||||
void executeSwitchType(); | void executeSwitchType(); | ||||
void executeSwitchTypes(); | |||||
/// Read a value from the bytecode buffer, optionally skipping a certain | /// Read a value from the bytecode buffer, optionally skipping a certain | ||||
/// number of prefix values. These methods always update the buffer to point | /// number of prefix values. These methods always update the buffer to point | ||||
/// to the next field after the read data. | /// to the next field after the read data. | ||||
template <typename T = ByteCodeField> | template <typename T = ByteCodeField> | ||||
T read(size_t skipN = 0) { | T read(size_t skipN = 0) { | ||||
curCodeIt += skipN; | curCodeIt += skipN; | ||||
return readImpl<T>(); | return readImpl<T>(); | ||||
} | } | ||||
ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); } | ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); } | ||||
/// Read a list of values from the bytecode buffer. | /// Read a list of values from the bytecode buffer. | ||||
template <typename ValueT, typename T> | template <typename ValueT, typename T> | ||||
void readList(SmallVectorImpl<T> &list) { | void readList(SmallVectorImpl<T> &list) { | ||||
list.clear(); | list.clear(); | ||||
for (unsigned i = 0, e = read(); i != e; ++i) | for (unsigned i = 0, e = read(); i != e; ++i) | ||||
list.push_back(read<ValueT>()); | list.push_back(read<ValueT>()); | ||||
} | } | ||||
/// Read a list of values from the bytecode buffer. The values may be encoded | |||||
/// as either Value or ValueRange elements. | |||||
void readValueList(SmallVectorImpl<Value> &list) { | |||||
for (unsigned i = 0, e = read(); i != e; ++i) { | |||||
if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { | |||||
list.push_back(read<Value>()); | |||||
} else { | |||||
ValueRange *values = read<ValueRange *>(); | |||||
list.append(values->begin(), values->end()); | |||||
} | |||||
} | |||||
} | |||||
/// Jump to a specific successor based on a predicate value. | /// Jump to a specific successor based on a predicate value. | ||||
void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); } | void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); } | ||||
/// Jump to a specific successor based on a destination index. | /// Jump to a specific successor based on a destination index. | ||||
void selectJump(size_t destIndex) { | void selectJump(size_t destIndex) { | ||||
curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)]; | curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)]; | ||||
} | } | ||||
/// Handle a switch operation with the provided value and cases. | /// Handle a switch operation with the provided value and cases. | ||||
template <typename T, typename RangeT> | template <typename T, typename RangeT, typename Comparator = std::equal_to<T>> | ||||
void handleSwitch(const T &value, RangeT &&cases) { | void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) { | ||||
LLVM_DEBUG({ | LLVM_DEBUG({ | ||||
llvm::dbgs() << " * Value: " << value << "\n" | llvm::dbgs() << " * Value: " << value << "\n" | ||||
<< " * Cases: "; | << " * Cases: "; | ||||
llvm::interleaveComma(cases, llvm::dbgs()); | llvm::interleaveComma(cases, llvm::dbgs()); | ||||
llvm::dbgs() << "\n"; | llvm::dbgs() << "\n"; | ||||
}); | }); | ||||
// Check to see if the attribute value is within the case list. Jump to | // Check to see if the attribute value is within the case list. Jump to | ||||
// the correct successor index based on the result. | // the correct successor index based on the result. | ||||
for (auto it = cases.begin(), e = cases.end(); it != e; ++it) | for (auto it = cases.begin(), e = cases.end(); it != e; ++it) | ||||
if (*it == value) | if (cmp(*it, value)) | ||||
return selectJump(size_t((it - cases.begin()) + 1)); | return selectJump(size_t((it - cases.begin()) + 1)); | ||||
selectJump(size_t(0)); | selectJump(size_t(0)); | ||||
} | } | ||||
/// Internal implementation of reading various data types from the bytecode | /// Internal implementation of reading various data types from the bytecode | ||||
/// stream. | /// stream. | ||||
template <typename T> | template <typename T> | ||||
const void *readFromMemory() { | const void *readFromMemory() { | ||||
size_t index = *curCodeIt++; | size_t index = *curCodeIt++; | ||||
// If this type is an SSA value, it can only be stored in non-const memory. | // If this type is an SSA value, it can only be stored in non-const memory. | ||||
if (llvm::is_one_of<T, Operation *, Value>::value || index < memory.size()) | if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *, | ||||
Value>::value || | |||||
index < memory.size()) | |||||
return memory[index]; | return memory[index]; | ||||
// Otherwise, if this index is not inbounds it is uniqued. | // Otherwise, if this index is not inbounds it is uniqued. | ||||
return uniquedMemory[index - memory.size()]; | return uniquedMemory[index - memory.size()]; | ||||
} | } | ||||
template <typename T> | template <typename T> | ||||
std::enable_if_t<std::is_pointer<T>::value, T> readImpl() { | std::enable_if_t<std::is_pointer<T>::value, T> readImpl() { | ||||
return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>())); | return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>())); | ||||
} | } | ||||
template <typename T> | template <typename T> | ||||
std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value, | std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value, | ||||
T> | T> | ||||
readImpl() { | readImpl() { | ||||
return T(T::getFromOpaquePointer(readFromMemory<T>())); | return T(T::getFromOpaquePointer(readFromMemory<T>())); | ||||
} | } | ||||
template <typename T> | template <typename T> | ||||
std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() { | std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() { | ||||
switch (static_cast<PDLValueKind>(read())) { | switch (read<PDLValue::Kind>()) { | ||||
case PDLValueKind::Attribute: | case PDLValue::Kind::Attribute: | ||||
return read<Attribute>(); | return read<Attribute>(); | ||||
case PDLValueKind::Operation: | case PDLValue::Kind::Operation: | ||||
return read<Operation *>(); | return read<Operation *>(); | ||||
case PDLValueKind::Type: | case PDLValue::Kind::Type: | ||||
return read<Type>(); | return read<Type>(); | ||||
case PDLValueKind::Value: | case PDLValue::Kind::Value: | ||||
return read<Value>(); | return read<Value>(); | ||||
case PDLValue::Kind::TypeRange: | |||||
return read<TypeRange *>(); | |||||
case PDLValue::Kind::ValueRange: | |||||
return read<ValueRange *>(); | |||||
} | } | ||||
llvm_unreachable("unhandled PDLValueKind"); | llvm_unreachable("unhandled PDLValue::Kind"); | ||||
} | } | ||||
template <typename T> | template <typename T> | ||||
std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() { | std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() { | ||||
static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, | static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, | ||||
"unexpected ByteCode address size"); | "unexpected ByteCode address size"); | ||||
ByteCodeAddr result; | ByteCodeAddr result; | ||||
std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr)); | std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr)); | ||||
curCodeIt += 2; | curCodeIt += 2; | ||||
return result; | return result; | ||||
} | } | ||||
template <typename T> | template <typename T> | ||||
std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() { | std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() { | ||||
return *curCodeIt++; | return *curCodeIt++; | ||||
} | } | ||||
template <typename T> | |||||
std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() { | |||||
return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>()); | |||||
} | |||||
/// The underlying bytecode buffer. | /// The underlying bytecode buffer. | ||||
const ByteCodeField *curCodeIt; | const ByteCodeField *curCodeIt; | ||||
/// The current execution memory. | /// The current execution memory. | ||||
MutableArrayRef<const void *> memory; | MutableArrayRef<const void *> memory; | ||||
MutableArrayRef<TypeRange> typeRangeMemory; | |||||
std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory; | |||||
MutableArrayRef<ValueRange> valueRangeMemory; | |||||
std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory; | |||||
/// References to ByteCode data necessary for execution. | /// References to ByteCode data necessary for execution. | ||||
ArrayRef<const void *> uniquedMemory; | ArrayRef<const void *> uniquedMemory; | ||||
ArrayRef<ByteCodeField> code; | ArrayRef<ByteCodeField> code; | ||||
ArrayRef<PatternBenefit> currentPatternBenefits; | ArrayRef<PatternBenefit> currentPatternBenefits; | ||||
ArrayRef<PDLByteCodePattern> patterns; | ArrayRef<PDLByteCodePattern> patterns; | ||||
ArrayRef<PDLConstraintFunction> constraintFunctions; | ArrayRef<PDLConstraintFunction> constraintFunctions; | ||||
ArrayRef<PDLRewriteFunction> rewriteFunctions; | ArrayRef<PDLRewriteFunction> rewriteFunctions; | ||||
}; | }; | ||||
/// This class is an instantiation of the PDLResultList that provides access to | /// This class is an instantiation of the PDLResultList that provides access to | ||||
/// the returned results. This API is not on `PDLResultList` to avoid | /// the returned results. This API is not on `PDLResultList` to avoid | ||||
/// overexposing access to information specific solely to the ByteCode. | /// overexposing access to information specific solely to the ByteCode. | ||||
class ByteCodeRewriteResultList : public PDLResultList { | class ByteCodeRewriteResultList : public PDLResultList { | ||||
public: | public: | ||||
ByteCodeRewriteResultList(unsigned maxNumResults) | |||||
: PDLResultList(maxNumResults) {} | |||||
/// Return the list of PDL results. | /// Return the list of PDL results. | ||||
MutableArrayRef<PDLValue> getResults() { return results; } | MutableArrayRef<PDLValue> getResults() { return results; } | ||||
/// Return the type ranges allocated by this list. | |||||
MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() { | |||||
return allocatedTypeRanges; | |||||
} | |||||
/// Return the value ranges allocated by this list. | |||||
MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() { | |||||
return allocatedValueRanges; | |||||
} | |||||
}; | }; | ||||
} // end anonymous namespace | } // end anonymous namespace | ||||
void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { | void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { | ||||
LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); | LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); | ||||
const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; | const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; | ||||
ArrayAttr constParams = read<ArrayAttr>(); | ArrayAttr constParams = read<ArrayAttr>(); | ||||
SmallVector<PDLValue, 16> args; | SmallVector<PDLValue, 16> args; | ||||
Show All 16 Lines | void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { | ||||
SmallVector<PDLValue, 16> args; | SmallVector<PDLValue, 16> args; | ||||
readList<PDLValue>(args); | readList<PDLValue>(args); | ||||
LLVM_DEBUG({ | LLVM_DEBUG({ | ||||
llvm::dbgs() << " * Arguments: "; | llvm::dbgs() << " * Arguments: "; | ||||
llvm::interleaveComma(args, llvm::dbgs()); | llvm::interleaveComma(args, llvm::dbgs()); | ||||
llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; | llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; | ||||
}); | }); | ||||
ByteCodeRewriteResultList results; | |||||
// Execute the rewrite function. | |||||
ByteCodeField numResults = read(); | |||||
ByteCodeRewriteResultList results(numResults); | |||||
rewriteFn(args, constParams, rewriter, results); | rewriteFn(args, constParams, rewriter, results); | ||||
// Store the results in the bytecode memory. | assert(results.getResults().size() == numResults && | ||||
#ifndef NDEBUG | |||||
ByteCodeField expectedNumberOfResults = read(); | |||||
assert(results.getResults().size() == expectedNumberOfResults && | |||||
"native PDL rewrite function returned unexpected number of results"); | "native PDL rewrite function returned unexpected number of results"); | ||||
#endif | |||||
// Store the results in the bytecode memory. | // Store the results in the bytecode memory. | ||||
for (PDLValue &result : results.getResults()) { | for (PDLValue &result : results.getResults()) { | ||||
LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); | LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); | ||||
// In debug mode we also verify the expected kind of the result. | |||||
#ifndef NDEBUG | |||||
assert(result.getKind() == read<PDLValue::Kind>() && | |||||
"native PDL rewrite function returned an unexpected type of result"); | |||||
#endif | |||||
// If the result is a range, we need to copy it over to the bytecodes | |||||
// range memory. | |||||
if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) { | |||||
unsigned rangeIndex = read(); | |||||
typeRangeMemory[rangeIndex] = *typeRange; | |||||
memory[read()] = &typeRangeMemory[rangeIndex]; | |||||
} else if (Optional<ValueRange> valueRange = | |||||
result.dyn_cast<ValueRange>()) { | |||||
unsigned rangeIndex = read(); | |||||
valueRangeMemory[rangeIndex] = *valueRange; | |||||
memory[read()] = &valueRangeMemory[rangeIndex]; | |||||
} else { | |||||
memory[read()] = result.getAsOpaquePointer(); | memory[read()] = result.getAsOpaquePointer(); | ||||
} | } | ||||
} | } | ||||
// Copy over any underlying storage allocated for result ranges. | |||||
for (auto &it : results.getAllocatedTypeRanges()) | |||||
allocatedTypeRangeMemory.push_back(std::move(it)); | |||||
for (auto &it : results.getAllocatedValueRanges()) | |||||
allocatedValueRangeMemory.push_back(std::move(it)); | |||||
} | |||||
void ByteCodeExecutor::executeAreEqual() { | void ByteCodeExecutor::executeAreEqual() { | ||||
LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); | LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); | ||||
const void *lhs = read<const void *>(); | const void *lhs = read<const void *>(); | ||||
const void *rhs = read<const void *>(); | const void *rhs = read<const void *>(); | ||||
LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n"); | LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n"); | ||||
selectJump(lhs == rhs); | selectJump(lhs == rhs); | ||||
} | } | ||||
void ByteCodeExecutor::executeAreRangesEqual() { | |||||
LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n"); | |||||
PDLValue::Kind valueKind = read<PDLValue::Kind>(); | |||||
const void *lhs = read<const void *>(); | |||||
const void *rhs = read<const void *>(); | |||||
switch (valueKind) { | |||||
case PDLValue::Kind::TypeRange: { | |||||
const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs); | |||||
const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs); | |||||
LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); | |||||
selectJump(*lhsRange == *rhsRange); | |||||
break; | |||||
} | |||||
case PDLValue::Kind::ValueRange: { | |||||
const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs); | |||||
const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs); | |||||
LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); | |||||
selectJump(*lhsRange == *rhsRange); | |||||
break; | |||||
} | |||||
default: | |||||
llvm_unreachable("unexpected `AreRangesEqual` value kind"); | |||||
} | |||||
} | |||||
void ByteCodeExecutor::executeBranch() { | void ByteCodeExecutor::executeBranch() { | ||||
LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n"); | LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n"); | ||||
curCodeIt = &code[read<ByteCodeAddr>()]; | curCodeIt = &code[read<ByteCodeAddr>()]; | ||||
} | } | ||||
void ByteCodeExecutor::executeCheckOperandCount() { | void ByteCodeExecutor::executeCheckOperandCount() { | ||||
LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n"); | LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n"); | ||||
Operation *op = read<Operation *>(); | Operation *op = read<Operation *>(); | ||||
uint32_t expectedCount = read<uint32_t>(); | uint32_t expectedCount = read<uint32_t>(); | ||||
bool compareAtLeast = read(); | |||||
LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n" | LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n" | ||||
<< " * Expected: " << expectedCount << "\n"); | << " * Expected: " << expectedCount << "\n" | ||||
<< " * Comparator: " | |||||
<< (compareAtLeast ? ">=" : "==") << "\n"); | |||||
if (compareAtLeast) | |||||
selectJump(op->getNumOperands() >= expectedCount); | |||||
else | |||||
selectJump(op->getNumOperands() == expectedCount); | selectJump(op->getNumOperands() == expectedCount); | ||||
} | } | ||||
void ByteCodeExecutor::executeCheckOperationName() { | void ByteCodeExecutor::executeCheckOperationName() { | ||||
LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n"); | LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n"); | ||||
Operation *op = read<Operation *>(); | Operation *op = read<Operation *>(); | ||||
OperationName expectedName = read<OperationName>(); | OperationName expectedName = read<OperationName>(); | ||||
LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n" | LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n" | ||||
<< " * Expected: \"" << expectedName << "\"\n"); | << " * Expected: \"" << expectedName << "\"\n"); | ||||
selectJump(op->getName() == expectedName); | selectJump(op->getName() == expectedName); | ||||
} | } | ||||
void ByteCodeExecutor::executeCheckResultCount() { | void ByteCodeExecutor::executeCheckResultCount() { | ||||
LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n"); | LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n"); | ||||
Operation *op = read<Operation *>(); | Operation *op = read<Operation *>(); | ||||
uint32_t expectedCount = read<uint32_t>(); | uint32_t expectedCount = read<uint32_t>(); | ||||
bool compareAtLeast = read(); | |||||
LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n" | LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n" | ||||
<< " * Expected: " << expectedCount << "\n"); | << " * Expected: " << expectedCount << "\n" | ||||
<< " * Comparator: " | |||||
<< (compareAtLeast ? ">=" : "==") << "\n"); | |||||
if (compareAtLeast) | |||||
selectJump(op->getNumResults() >= expectedCount); | |||||
else | |||||
selectJump(op->getNumResults() == expectedCount); | selectJump(op->getNumResults() == expectedCount); | ||||
} | } | ||||
void ByteCodeExecutor::executeCheckTypes() { | |||||
LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); | |||||
TypeRange *lhs = read<TypeRange *>(); | |||||
Attribute rhs = read<Attribute>(); | |||||
LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); | |||||
selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>()); | |||||
} | |||||
void ByteCodeExecutor::executeCreateTypes() { | |||||
LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n"); | |||||
unsigned memIndex = read(); | |||||
unsigned rangeIndex = read(); | |||||
ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>(); | |||||
LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n"); | |||||
// Allocate a buffer for this type range. | |||||
llvm::OwningArrayRef<Type> storage(typesAttr.size()); | |||||
llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin()); | |||||
allocatedTypeRangeMemory.emplace_back(std::move(storage)); | |||||
// Assign this to the range slot and use the range as the value for the | |||||
// memory index. | |||||
typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back(); | |||||
memory[memIndex] = &typeRangeMemory[rangeIndex]; | |||||
} | |||||
void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, | void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, | ||||
Location mainRewriteLoc) { | Location mainRewriteLoc) { | ||||
LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n"); | LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n"); | ||||
unsigned memIndex = read(); | unsigned memIndex = read(); | ||||
OperationState state(mainRewriteLoc, read<OperationName>()); | OperationState state(mainRewriteLoc, read<OperationName>()); | ||||
readList<Value>(state.operands); | readValueList(state.operands); | ||||
for (unsigned i = 0, e = read(); i != e; ++i) { | for (unsigned i = 0, e = read(); i != e; ++i) { | ||||
Identifier name = read<Identifier>(); | Identifier name = read<Identifier>(); | ||||
if (Attribute attr = read<Attribute>()) | if (Attribute attr = read<Attribute>()) | ||||
state.addAttribute(name, attr); | state.addAttribute(name, attr); | ||||
} | } | ||||
bool hasInferredTypes = false; | |||||
for (unsigned i = 0, e = read(); i != e; ++i) { | for (unsigned i = 0, e = read(); i != e; ++i) { | ||||
Type resultType = read<Type>(); | if (read<PDLValue::Kind>() == PDLValue::Kind::Type) { | ||||
hasInferredTypes |= !resultType; | state.types.push_back(read<Type>()); | ||||
state.types.push_back(resultType); | continue; | ||||
} | |||||
// If we find a null range, this signals that the types are infered. | |||||
if (TypeRange *resultTypes = read<TypeRange *>()) { | |||||
state.types.append(resultTypes->begin(), resultTypes->end()); | |||||
continue; | |||||
} | } | ||||
// Handle the case where the operation has inferred types. | // Handle the case where the operation has inferred types. | ||||
if (hasInferredTypes) { | |||||
InferTypeOpInterface::Concept *concept = | InferTypeOpInterface::Concept *concept = | ||||
state.name.getAbstractOperation()->getInterface<InferTypeOpInterface>(); | state.name.getAbstractOperation()->getInterface<InferTypeOpInterface>(); | ||||
// TODO: Handle failure. | // TODO: Handle failure. | ||||
state.types.clear(); | state.types.clear(); | ||||
if (failed(concept->inferReturnTypes( | if (failed(concept->inferReturnTypes( | ||||
state.getContext(), state.location, state.operands, | state.getContext(), state.location, state.operands, | ||||
state.attributes.getDictionary(state.getContext()), state.regions, | state.attributes.getDictionary(state.getContext()), state.regions, | ||||
state.types))) | state.types))) | ||||
return; | return; | ||||
break; | |||||
} | } | ||||
Operation *resultOp = rewriter.createOperation(state); | Operation *resultOp = rewriter.createOperation(state); | ||||
memory[memIndex] = resultOp; | memory[memIndex] = resultOp; | ||||
LLVM_DEBUG({ | LLVM_DEBUG({ | ||||
llvm::dbgs() << " * Attributes: " | llvm::dbgs() << " * Attributes: " | ||||
<< state.attributes.getDictionary(state.getContext()) | << state.attributes.getDictionary(state.getContext()) | ||||
<< "\n * Operands: "; | << "\n * Operands: "; | ||||
llvm::interleaveComma(state.operands, llvm::dbgs()); | llvm::interleaveComma(state.operands, llvm::dbgs()); | ||||
Show All 33 Lines | void ByteCodeExecutor::executeGetAttributeType() { | ||||
LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n" | LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n" | ||||
<< " * Result: " << type << "\n"); | << " * Result: " << type << "\n"); | ||||
memory[memIndex] = type.getAsOpaquePointer(); | memory[memIndex] = type.getAsOpaquePointer(); | ||||
} | } | ||||
void ByteCodeExecutor::executeGetDefiningOp() { | void ByteCodeExecutor::executeGetDefiningOp() { | ||||
LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n"); | LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n"); | ||||
unsigned memIndex = read(); | unsigned memIndex = read(); | ||||
Operation *op = nullptr; | |||||
if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { | |||||
Value value = read<Value>(); | Value value = read<Value>(); | ||||
Operation *op = value ? value.getDefiningOp() : nullptr; | if (value) | ||||
op = value.getDefiningOp(); | |||||
LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); | |||||
} else { | |||||
ValueRange *values = read<ValueRange *>(); | |||||
if (values && !values->empty()) { | |||||
op = values->front().getDefiningOp(); | |||||
} | |||||
LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n"); | |||||
} | |||||
LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" | LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n"); | ||||
<< " * Result: " << *op << "\n"); | |||||
memory[memIndex] = op; | memory[memIndex] = op; | ||||
} | } | ||||
void ByteCodeExecutor::executeGetOperand(unsigned index) { | void ByteCodeExecutor::executeGetOperand(unsigned index) { | ||||
Operation *op = read<Operation *>(); | Operation *op = read<Operation *>(); | ||||
unsigned memIndex = read(); | unsigned memIndex = read(); | ||||
Value operand = | Value operand = | ||||
index < op->getNumOperands() ? op->getOperand(index) : Value(); | index < op->getNumOperands() ? op->getOperand(index) : Value(); | ||||
LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" | LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" | ||||
<< " * Index: " << index << "\n" | << " * Index: " << index << "\n" | ||||
<< " * Result: " << operand << "\n"); | << " * Result: " << operand << "\n"); | ||||
memory[memIndex] = operand.getAsOpaquePointer(); | memory[memIndex] = operand.getAsOpaquePointer(); | ||||
} | } | ||||
/// This function is the internal implementation of `GetResults` and | |||||
/// `GetOperands` that provides support for extracting a value range from the | |||||
/// given operation. | |||||
template <template <typename> class AttrSizedSegmentsT, typename RangeT> | |||||
static void * | |||||
executeGetOperandsResults(RangeT values, Operation *op, unsigned index, | |||||
ByteCodeField rangeIndex, StringRef attrSizedSegments, | |||||
MutableArrayRef<ValueRange> &valueRangeMemory) { | |||||
// Check for the sentinel index that signals that all values should be | |||||
// returned. | |||||
if (index == std::numeric_limits<uint32_t>::max()) { | |||||
LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n"); | |||||
// `values` is already the full value range. | |||||
// Otherwise, check to see if this operation uses AttrSizedSegments. | |||||
} else if (op->hasTrait<AttrSizedSegmentsT>()) { | |||||
LLVM_DEBUG(llvm::dbgs() | |||||
<< " * Extracting values from `" << attrSizedSegments << "`\n"); | |||||
auto segmentAttr = op->getAttrOfType<DenseElementsAttr>(attrSizedSegments); | |||||
if (!segmentAttr || segmentAttr.getNumElements() <= index) | |||||
return nullptr; | |||||
auto segments = segmentAttr.getValues<int32_t>(); | |||||
unsigned startIndex = | |||||
std::accumulate(segments.begin(), segments.begin() + index, 0); | |||||
values = values.slice(startIndex, *std::next(segments.begin(), index)); | |||||
LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", " | |||||
<< *std::next(segments.begin(), index) << "]\n"); | |||||
// Otherwise, assume this is the last operand group of the operation. | |||||
// FIXME: We currently don't support operations with | |||||
// SameVariadicOperandSize/SameVariadicResultSize here given that we don't | |||||
// have a way to detect it's presence. | |||||
} else if (values.size() >= index) { | |||||
LLVM_DEBUG(llvm::dbgs() | |||||
<< " * Treating values as trailing variadic range\n"); | |||||
values = values.drop_front(index); | |||||
// If we couldn't detect a way to compute the values, bail out. | |||||
} else { | |||||
return nullptr; | |||||
} | |||||
// If the range index is valid, we are returning a range. | |||||
if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) { | |||||
valueRangeMemory[rangeIndex] = values; | |||||
return &valueRangeMemory[rangeIndex]; | |||||
} | |||||
// If a range index wasn't provided, the range is required to be non-variadic. | |||||
return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer(); | |||||
} | |||||
void ByteCodeExecutor::executeGetOperands() { | |||||
LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n"); | |||||
unsigned index = read<uint32_t>(); | |||||
Operation *op = read<Operation *>(); | |||||
ByteCodeField rangeIndex = read(); | |||||
void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>( | |||||
op->getOperands(), op, index, rangeIndex, "operand_segment_sizes", | |||||
valueRangeMemory); | |||||
if (!result) | |||||
LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n"); | |||||
memory[read()] = result; | |||||
} | |||||
void ByteCodeExecutor::executeGetResult(unsigned index) { | void ByteCodeExecutor::executeGetResult(unsigned index) { | ||||
Operation *op = read<Operation *>(); | Operation *op = read<Operation *>(); | ||||
unsigned memIndex = read(); | unsigned memIndex = read(); | ||||
OpResult result = | OpResult result = | ||||
index < op->getNumResults() ? op->getResult(index) : OpResult(); | index < op->getNumResults() ? op->getResult(index) : OpResult(); | ||||
LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" | LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" | ||||
<< " * Index: " << index << "\n" | << " * Index: " << index << "\n" | ||||
<< " * Result: " << result << "\n"); | << " * Result: " << result << "\n"); | ||||
memory[memIndex] = result.getAsOpaquePointer(); | memory[memIndex] = result.getAsOpaquePointer(); | ||||
} | } | ||||
void ByteCodeExecutor::executeGetResults() { | |||||
LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n"); | |||||
unsigned index = read<uint32_t>(); | |||||
Operation *op = read<Operation *>(); | |||||
ByteCodeField rangeIndex = read(); | |||||
void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>( | |||||
op->getResults(), op, index, rangeIndex, "result_segment_sizes", | |||||
valueRangeMemory); | |||||
if (!result) | |||||
LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n"); | |||||
memory[read()] = result; | |||||
} | |||||
void ByteCodeExecutor::executeGetValueType() { | void ByteCodeExecutor::executeGetValueType() { | ||||
LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n"); | LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n"); | ||||
unsigned memIndex = read(); | unsigned memIndex = read(); | ||||
Value value = read<Value>(); | Value value = read<Value>(); | ||||
Type type = value ? value.getType() : Type(); | Type type = value ? value.getType() : Type(); | ||||
LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" | LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" | ||||
<< " * Result: " << type << "\n"); | << " * Result: " << type << "\n"); | ||||
memory[memIndex] = type.getAsOpaquePointer(); | memory[memIndex] = type.getAsOpaquePointer(); | ||||
} | } | ||||
void ByteCodeExecutor::executeGetValueRangeTypes() { | |||||
LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n"); | |||||
unsigned memIndex = read(); | |||||
unsigned rangeIndex = read(); | |||||
ValueRange *values = read<ValueRange *>(); | |||||
if (!values) { | |||||
LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n"); | |||||
memory[memIndex] = nullptr; | |||||
return; | |||||
} | |||||
LLVM_DEBUG({ | |||||
llvm::dbgs() << " * Values (" << values->size() << "): "; | |||||
llvm::interleaveComma(*values, llvm::dbgs()); | |||||
llvm::dbgs() << "\n * Result: "; | |||||
llvm::interleaveComma(values->getType(), llvm::dbgs()); | |||||
llvm::dbgs() << "\n"; | |||||
}); | |||||
typeRangeMemory[rangeIndex] = values->getType(); | |||||
memory[memIndex] = &typeRangeMemory[rangeIndex]; | |||||
} | |||||
void ByteCodeExecutor::executeIsNotNull() { | void ByteCodeExecutor::executeIsNotNull() { | ||||
LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n"); | LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n"); | ||||
const void *value = read<const void *>(); | const void *value = read<const void *>(); | ||||
LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); | LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); | ||||
selectJump(value != nullptr); | selectJump(value != nullptr); | ||||
} | } | ||||
Show All 22 Lines | void ByteCodeExecutor::executeRecordMatch( | ||||
matchLocs.reserve(numMatchLocs); | matchLocs.reserve(numMatchLocs); | ||||
for (unsigned i = 0; i != numMatchLocs; ++i) | for (unsigned i = 0; i != numMatchLocs; ++i) | ||||
matchLocs.push_back(read<Operation *>()->getLoc()); | matchLocs.push_back(read<Operation *>()->getLoc()); | ||||
Location matchLoc = rewriter.getFusedLoc(matchLocs); | Location matchLoc = rewriter.getFusedLoc(matchLocs); | ||||
LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n" | LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n" | ||||
<< " * Location: " << matchLoc << "\n"); | << " * Location: " << matchLoc << "\n"); | ||||
matches.emplace_back(matchLoc, patterns[patternIndex], benefit); | matches.emplace_back(matchLoc, patterns[patternIndex], benefit); | ||||
readList<const void *>(matches.back().values); | PDLByteCode::MatchResult &match = matches.back(); | ||||
// Record all of the inputs to the match. If any of the inputs are ranges, we | |||||
// will also need to remap the range pointer to memory stored in the match | |||||
// state. | |||||
unsigned numInputs = read(); | |||||
match.values.reserve(numInputs); | |||||
match.typeRangeValues.reserve(numInputs); | |||||
match.valueRangeValues.reserve(numInputs); | |||||
for (unsigned i = 0; i < numInputs; ++i) { | |||||
switch (read<PDLValue::Kind>()) { | |||||
case PDLValue::Kind::TypeRange: | |||||
match.typeRangeValues.push_back(*read<TypeRange *>()); | |||||
match.values.push_back(&match.typeRangeValues.back()); | |||||
break; | |||||
case PDLValue::Kind::ValueRange: | |||||
match.valueRangeValues.push_back(*read<ValueRange *>()); | |||||
match.values.push_back(&match.valueRangeValues.back()); | |||||
break; | |||||
default: | |||||
match.values.push_back(read<const void *>()); | |||||
break; | |||||
} | |||||
} | |||||
curCodeIt = dest; | curCodeIt = dest; | ||||
} | } | ||||
void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) { | void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) { | ||||
LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n"); | LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n"); | ||||
Operation *op = read<Operation *>(); | Operation *op = read<Operation *>(); | ||||
SmallVector<Value, 16> args; | SmallVector<Value, 16> args; | ||||
readList<Value>(args); | readValueList(args); | ||||
LLVM_DEBUG({ | LLVM_DEBUG({ | ||||
llvm::dbgs() << " * Operation: " << *op << "\n" | llvm::dbgs() << " * Operation: " << *op << "\n" | ||||
<< " * Values: "; | << " * Values: "; | ||||
llvm::interleaveComma(args, llvm::dbgs()); | llvm::interleaveComma(args, llvm::dbgs()); | ||||
llvm::dbgs() << "\n"; | llvm::dbgs() << "\n"; | ||||
}); | }); | ||||
rewriter.replaceOp(op, args); | rewriter.replaceOp(op, args); | ||||
▲ Show 20 Lines • Show All 56 Lines • ▼ Show 20 Lines | |||||
void ByteCodeExecutor::executeSwitchType() { | void ByteCodeExecutor::executeSwitchType() { | ||||
LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n"); | LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n"); | ||||
Type value = read<Type>(); | Type value = read<Type>(); | ||||
auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>(); | auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>(); | ||||
handleSwitch(value, cases); | handleSwitch(value, cases); | ||||
} | } | ||||
void ByteCodeExecutor::executeSwitchTypes() { | |||||
LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n"); | |||||
TypeRange *value = read<TypeRange *>(); | |||||
auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>(); | |||||
if (!value) { | |||||
LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n"); | |||||
return selectJump(size_t(0)); | |||||
} | |||||
handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) { | |||||
return value == caseValue.getAsValueRange<TypeAttr>(); | |||||
}); | |||||
} | |||||
void ByteCodeExecutor::execute( | void ByteCodeExecutor::execute( | ||||
PatternRewriter &rewriter, | PatternRewriter &rewriter, | ||||
SmallVectorImpl<PDLByteCode::MatchResult> *matches, | SmallVectorImpl<PDLByteCode::MatchResult> *matches, | ||||
Optional<Location> mainRewriteLoc) { | Optional<Location> mainRewriteLoc) { | ||||
while (true) { | while (true) { | ||||
OpCode opCode = static_cast<OpCode>(read()); | OpCode opCode = static_cast<OpCode>(read()); | ||||
switch (opCode) { | switch (opCode) { | ||||
case ApplyConstraint: | case ApplyConstraint: | ||||
executeApplyConstraint(rewriter); | executeApplyConstraint(rewriter); | ||||
break; | break; | ||||
case ApplyRewrite: | case ApplyRewrite: | ||||
executeApplyRewrite(rewriter); | executeApplyRewrite(rewriter); | ||||
break; | break; | ||||
case AreEqual: | case AreEqual: | ||||
executeAreEqual(); | executeAreEqual(); | ||||
break; | break; | ||||
case AreRangesEqual: | |||||
executeAreRangesEqual(); | |||||
break; | |||||
case Branch: | case Branch: | ||||
executeBranch(); | executeBranch(); | ||||
break; | break; | ||||
case CheckOperandCount: | case CheckOperandCount: | ||||
executeCheckOperandCount(); | executeCheckOperandCount(); | ||||
break; | break; | ||||
case CheckOperationName: | case CheckOperationName: | ||||
executeCheckOperationName(); | executeCheckOperationName(); | ||||
break; | break; | ||||
case CheckResultCount: | case CheckResultCount: | ||||
executeCheckResultCount(); | executeCheckResultCount(); | ||||
break; | break; | ||||
case CheckTypes: | |||||
executeCheckTypes(); | |||||
break; | |||||
case CreateOperation: | case CreateOperation: | ||||
executeCreateOperation(rewriter, *mainRewriteLoc); | executeCreateOperation(rewriter, *mainRewriteLoc); | ||||
break; | break; | ||||
case CreateTypes: | |||||
executeCreateTypes(); | |||||
break; | |||||
case EraseOp: | case EraseOp: | ||||
executeEraseOp(rewriter); | executeEraseOp(rewriter); | ||||
break; | break; | ||||
case Finalize: | case Finalize: | ||||
LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n"); | LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n"); | ||||
return; | return; | ||||
case GetAttribute: | case GetAttribute: | ||||
executeGetAttribute(); | executeGetAttribute(); | ||||
Show All 12 Lines | case GetOperand3: { | ||||
LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n"); | LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n"); | ||||
executeGetOperand(index); | executeGetOperand(index); | ||||
break; | break; | ||||
} | } | ||||
case GetOperandN: | case GetOperandN: | ||||
LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n"); | LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n"); | ||||
executeGetOperand(read<uint32_t>()); | executeGetOperand(read<uint32_t>()); | ||||
break; | break; | ||||
case GetOperands: | |||||
executeGetOperands(); | |||||
break; | |||||
case GetResult0: | case GetResult0: | ||||
case GetResult1: | case GetResult1: | ||||
case GetResult2: | case GetResult2: | ||||
case GetResult3: { | case GetResult3: { | ||||
unsigned index = opCode - GetResult0; | unsigned index = opCode - GetResult0; | ||||
LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n"); | LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n"); | ||||
executeGetResult(index); | executeGetResult(index); | ||||
break; | break; | ||||
} | } | ||||
case GetResultN: | case GetResultN: | ||||
LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n"); | LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n"); | ||||
executeGetResult(read<uint32_t>()); | executeGetResult(read<uint32_t>()); | ||||
break; | break; | ||||
case GetResults: | |||||
executeGetResults(); | |||||
break; | |||||
case GetValueType: | case GetValueType: | ||||
executeGetValueType(); | executeGetValueType(); | ||||
break; | break; | ||||
case GetValueRangeTypes: | |||||
executeGetValueRangeTypes(); | |||||
break; | |||||
case IsNotNull: | case IsNotNull: | ||||
executeIsNotNull(); | executeIsNotNull(); | ||||
break; | break; | ||||
case RecordMatch: | case RecordMatch: | ||||
assert(matches && | assert(matches && | ||||
"expected matches to be provided when executing the matcher"); | "expected matches to be provided when executing the matcher"); | ||||
executeRecordMatch(rewriter, *matches); | executeRecordMatch(rewriter, *matches); | ||||
break; | break; | ||||
Show All 10 Lines | case SwitchOperationName: | ||||
executeSwitchOperationName(); | executeSwitchOperationName(); | ||||
break; | break; | ||||
case SwitchResultCount: | case SwitchResultCount: | ||||
executeSwitchResultCount(); | executeSwitchResultCount(); | ||||
break; | break; | ||||
case SwitchType: | case SwitchType: | ||||
executeSwitchType(); | executeSwitchType(); | ||||
break; | break; | ||||
case SwitchTypes: | |||||
executeSwitchTypes(); | |||||
break; | |||||
} | } | ||||
LLVM_DEBUG(llvm::dbgs() << "\n"); | LLVM_DEBUG(llvm::dbgs() << "\n"); | ||||
} | } | ||||
} | } | ||||
/// Run the pattern matcher on the given root operation, collecting the matched | /// Run the pattern matcher on the given root operation, collecting the matched | ||||
/// patterns in `matches`. | /// patterns in `matches`. | ||||
void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, | void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, | ||||
SmallVectorImpl<MatchResult> &matches, | SmallVectorImpl<MatchResult> &matches, | ||||
PDLByteCodeMutableState &state) const { | PDLByteCodeMutableState &state) const { | ||||
// The first memory slot is always the root operation. | // The first memory slot is always the root operation. | ||||
state.memory[0] = op; | state.memory[0] = op; | ||||
// The matcher function always starts at code address 0. | // The matcher function always starts at code address 0. | ||||
ByteCodeExecutor executor(matcherByteCode.data(), state.memory, uniquedData, | ByteCodeExecutor executor( | ||||
matcherByteCode, state.currentPatternBenefits, | matcherByteCode.data(), state.memory, state.typeRangeMemory, | ||||
patterns, constraintFunctions, rewriteFunctions); | state.allocatedTypeRangeMemory, state.valueRangeMemory, | ||||
state.allocatedValueRangeMemory, uniquedData, matcherByteCode, | |||||
state.currentPatternBenefits, patterns, constraintFunctions, | |||||
rewriteFunctions); | |||||
executor.execute(rewriter, &matches); | executor.execute(rewriter, &matches); | ||||
// Order the found matches by benefit. | // Order the found matches by benefit. | ||||
std::stable_sort(matches.begin(), matches.end(), | std::stable_sort(matches.begin(), matches.end(), | ||||
[](const MatchResult &lhs, const MatchResult &rhs) { | [](const MatchResult &lhs, const MatchResult &rhs) { | ||||
return lhs.benefit > rhs.benefit; | return lhs.benefit > rhs.benefit; | ||||
}); | }); | ||||
} | } | ||||
/// Run the rewriter of the given pattern on the root operation `op`. | /// Run the rewriter of the given pattern on the root operation `op`. | ||||
void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match, | void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match, | ||||
PDLByteCodeMutableState &state) const { | PDLByteCodeMutableState &state) const { | ||||
// The arguments of the rewrite function are stored at the start of the | // The arguments of the rewrite function are stored at the start of the | ||||
// memory buffer. | // memory buffer. | ||||
llvm::copy(match.values, state.memory.begin()); | llvm::copy(match.values, state.memory.begin()); | ||||
ByteCodeExecutor executor(&rewriterByteCode[match.pattern->getRewriterAddr()], | ByteCodeExecutor executor( | ||||
state.memory, uniquedData, rewriterByteCode, | &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory, | ||||
state.currentPatternBenefits, patterns, | state.typeRangeMemory, state.allocatedTypeRangeMemory, | ||||
state.valueRangeMemory, state.allocatedValueRangeMemory, uniquedData, | |||||
rewriterByteCode, state.currentPatternBenefits, patterns, | |||||
constraintFunctions, rewriteFunctions); | constraintFunctions, rewriteFunctions); | ||||
executor.execute(rewriter, /*matches=*/nullptr, match.location); | executor.execute(rewriter, /*matches=*/nullptr, match.location); | ||||
} | } |
I stream output of these I found useful for debugging before. Not needed as part of this.