diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -732,6 +732,34 @@ Optional mainRewriteLoc = {}); private: + /// Internal implementation of executing each of the bytecode commands. + void executeApplyConstraint(PatternRewriter &rewriter); + void executeApplyRewrite(PatternRewriter &rewriter); + void executeAreEqual(); + void executeBranch(); + void executeCheckOperandCount(); + void executeCheckOperationName(); + void executeCheckResultCount(); + void executeCreateNative(PatternRewriter &rewriter); + void executeCreateOperation(PatternRewriter &rewriter, + Location mainRewriteLoc); + void executeEraseOp(PatternRewriter &rewriter); + void executeGetAttribute(); + void executeGetAttributeType(); + void executeGetDefiningOp(); + void executeGetOperand(unsigned index); + void executeGetResult(unsigned index); + void executeGetValueType(); + void executeIsNotNull(); + void executeRecordMatch(PatternRewriter &rewriter, + SmallVectorImpl &matches); + void executeReplaceOp(PatternRewriter &rewriter); + void executeSwitchAttribute(); + void executeSwitchOperandCount(); + void executeSwitchOperationName(); + void executeSwitchResultCount(); + void executeSwitchType(); + /// Read a value from the bytecode buffer, optionally skipping a certain /// number of prefix values. These methods always update the buffer to point /// to the next field after the read data. @@ -764,7 +792,7 @@ llvm::dbgs() << " * Value: " << value << "\n" << " * Cases: "; llvm::interleaveComma(cases, llvm::dbgs()); - llvm::dbgs() << "\n\n"; + llvm::dbgs() << "\n"; }); // Check to see if the attribute value is within the case list. Jump to @@ -843,6 +871,353 @@ }; } // end anonymous namespace +void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { + LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); + const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; + ArrayAttr constParams = read(); + SmallVector args; + readList(args); + + LLVM_DEBUG({ + llvm::dbgs() << " * Arguments: "; + llvm::interleaveComma(args, llvm::dbgs()); + llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; + }); + + // Invoke the constraint and jump to the proper destination. + selectJump(succeeded(constraintFn(args, constParams, rewriter))); +} + +void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { + LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); + const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; + ArrayAttr constParams = read(); + Operation *root = read(); + SmallVector args; + readList(args); + + LLVM_DEBUG({ + llvm::dbgs() << " * Root: " << *root << "\n * Arguments: "; + llvm::interleaveComma(args, llvm::dbgs()); + llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; + }); + + // Invoke the native rewrite function. + rewriteFn(root, args, constParams, rewriter); +} + +void ByteCodeExecutor::executeAreEqual() { + LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); + const void *lhs = read(); + const void *rhs = read(); + + LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n"); + selectJump(lhs == rhs); +} + +void ByteCodeExecutor::executeBranch() { + LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n"); + curCodeIt = &code[read()]; +} + +void ByteCodeExecutor::executeCheckOperandCount() { + LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n"); + Operation *op = read(); + uint32_t expectedCount = read(); + + LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n" + << " * Expected: " << expectedCount << "\n"); + selectJump(op->getNumOperands() == expectedCount); +} + +void ByteCodeExecutor::executeCheckOperationName() { + LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n"); + Operation *op = read(); + OperationName expectedName = read(); + + LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n" + << " * Expected: \"" << expectedName << "\"\n"); + selectJump(op->getName() == expectedName); +} + +void ByteCodeExecutor::executeCheckResultCount() { + LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n"); + Operation *op = read(); + uint32_t expectedCount = read(); + + LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n" + << " * Expected: " << expectedCount << "\n"); + selectJump(op->getNumResults() == expectedCount); +} + +void ByteCodeExecutor::executeCreateNative(PatternRewriter &rewriter) { + LLVM_DEBUG(llvm::dbgs() << "Executing CreateNative:\n"); + const PDLCreateFunction &createFn = createFunctions[read()]; + ByteCodeField resultIndex = read(); + ArrayAttr constParams = read(); + SmallVector args; + readList(args); + + LLVM_DEBUG({ + llvm::dbgs() << " * Arguments: "; + llvm::interleaveComma(args, llvm::dbgs()); + llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; + }); + + PDLValue result = createFn(args, constParams, rewriter); + memory[resultIndex] = result.getAsOpaquePointer(); + + LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); +} + +void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, + Location mainRewriteLoc) { + LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n"); + + unsigned memIndex = read(); + OperationState state(mainRewriteLoc, read()); + readList(state.operands); + for (unsigned i = 0, e = read(); i != e; ++i) { + Identifier name = read(); + if (Attribute attr = read()) + state.addAttribute(name, attr); + } + + bool hasInferredTypes = false; + for (unsigned i = 0, e = read(); i != e; ++i) { + Type resultType = read(); + hasInferredTypes |= !resultType; + state.types.push_back(resultType); + } + + // Handle the case where the operation has inferred types. + if (hasInferredTypes) { + InferTypeOpInterface::Concept *concept = + state.name.getAbstractOperation()->getInterface(); + + // TODO: Handle failure. + SmallVector inferredTypes; + if (failed(concept->inferReturnTypes( + state.getContext(), state.location, state.operands, + state.attributes.getDictionary(state.getContext()), state.regions, + inferredTypes))) + return; + + for (unsigned i = 0, e = state.types.size(); i != e; ++i) + if (!state.types[i]) + state.types[i] = inferredTypes[i]; + } + Operation *resultOp = rewriter.createOperation(state); + memory[memIndex] = resultOp; + + LLVM_DEBUG({ + llvm::dbgs() << " * Attributes: " + << state.attributes.getDictionary(state.getContext()) + << "\n * Operands: "; + llvm::interleaveComma(state.operands, llvm::dbgs()); + llvm::dbgs() << "\n * Result Types: "; + llvm::interleaveComma(state.types, llvm::dbgs()); + llvm::dbgs() << "\n * Result: " << *resultOp << "\n"; + }); +} + +void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) { + LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n"); + Operation *op = read(); + + LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); + rewriter.eraseOp(op); +} + +void ByteCodeExecutor::executeGetAttribute() { + LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n"); + unsigned memIndex = read(); + Operation *op = read(); + Identifier attrName = read(); + Attribute attr = op->getAttr(attrName); + + LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" + << " * Attribute: " << attrName << "\n" + << " * Result: " << attr << "\n"); + memory[memIndex] = attr.getAsOpaquePointer(); +} + +void ByteCodeExecutor::executeGetAttributeType() { + LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n"); + unsigned memIndex = read(); + Attribute attr = read(); + Type type = attr ? attr.getType() : Type(); + + LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n" + << " * Result: " << type << "\n"); + memory[memIndex] = type.getAsOpaquePointer(); +} + +void ByteCodeExecutor::executeGetDefiningOp() { + LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n"); + unsigned memIndex = read(); + Value value = read(); + Operation *op = value ? value.getDefiningOp() : nullptr; + + LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" + << " * Result: " << *op << "\n"); + memory[memIndex] = op; +} + +void ByteCodeExecutor::executeGetOperand(unsigned index) { + Operation *op = read(); + unsigned memIndex = read(); + Value operand = + index < op->getNumOperands() ? op->getOperand(index) : Value(); + + LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" + << " * Index: " << index << "\n" + << " * Result: " << operand << "\n"); + memory[memIndex] = operand.getAsOpaquePointer(); +} + +void ByteCodeExecutor::executeGetResult(unsigned index) { + Operation *op = read(); + unsigned memIndex = read(); + OpResult result = + index < op->getNumResults() ? op->getResult(index) : OpResult(); + + LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" + << " * Index: " << index << "\n" + << " * Result: " << result << "\n"); + memory[memIndex] = result.getAsOpaquePointer(); +} + +void ByteCodeExecutor::executeGetValueType() { + LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n"); + unsigned memIndex = read(); + Value value = read(); + Type type = value ? value.getType() : Type(); + + LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" + << " * Result: " << type << "\n"); + memory[memIndex] = type.getAsOpaquePointer(); +} + +void ByteCodeExecutor::executeIsNotNull() { + LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n"); + const void *value = read(); + + LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); + selectJump(value != nullptr); +} + +void ByteCodeExecutor::executeRecordMatch( + PatternRewriter &rewriter, + SmallVectorImpl &matches) { + LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n"); + unsigned patternIndex = read(); + PatternBenefit benefit = currentPatternBenefits[patternIndex]; + const ByteCodeField *dest = &code[read()]; + + // If the benefit of the pattern is impossible, skip the processing of the + // rest of the pattern. + if (benefit.isImpossibleToMatch()) { + LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n"); + curCodeIt = dest; + return; + } + + // Create a fused location containing the locations of each of the + // operations used in the match. This will be used as the location for + // created operations during the rewrite that don't already have an + // explicit location set. + unsigned numMatchLocs = read(); + SmallVector matchLocs; + matchLocs.reserve(numMatchLocs); + for (unsigned i = 0; i != numMatchLocs; ++i) + matchLocs.push_back(read()->getLoc()); + Location matchLoc = rewriter.getFusedLoc(matchLocs); + + LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n" + << " * Location: " << matchLoc << "\n"); + matches.emplace_back(matchLoc, patterns[patternIndex], benefit); + readList(matches.back().values); + curCodeIt = dest; +} + +void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) { + LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n"); + Operation *op = read(); + SmallVector args; + readList(args); + + LLVM_DEBUG({ + llvm::dbgs() << " * Operation: " << *op << "\n" + << " * Values: "; + llvm::interleaveComma(args, llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + rewriter.replaceOp(op, args); +} + +void ByteCodeExecutor::executeSwitchAttribute() { + LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n"); + Attribute value = read(); + ArrayAttr cases = read(); + handleSwitch(value, cases); +} + +void ByteCodeExecutor::executeSwitchOperandCount() { + LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n"); + Operation *op = read(); + auto cases = read().getValues(); + + LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); + handleSwitch(op->getNumOperands(), cases); +} + +void ByteCodeExecutor::executeSwitchOperationName() { + LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n"); + OperationName value = read()->getName(); + size_t caseCount = read(); + + // The operation names are stored in-line, so to print them out for + // debugging purposes we need to read the array before executing the + // switch so that we can display all of the possible values. + LLVM_DEBUG({ + const ByteCodeField *prevCodeIt = curCodeIt; + llvm::dbgs() << " * Value: " << value << "\n" + << " * Cases: "; + llvm::interleaveComma( + llvm::map_range(llvm::seq(0, caseCount), + [&](size_t) { return read(); }), + llvm::dbgs()); + llvm::dbgs() << "\n"; + curCodeIt = prevCodeIt; + }); + + // Try to find the switch value within any of the cases. + for (size_t i = 0; i != caseCount; ++i) { + if (read() == value) { + curCodeIt += (caseCount - i - 1); + return selectJump(i + 1); + } + } + selectJump(size_t(0)); +} + +void ByteCodeExecutor::executeSwitchResultCount() { + LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n"); + Operation *op = read(); + auto cases = read().getValues(); + + LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); + handleSwitch(op->getNumResults(), cases); +} + +void ByteCodeExecutor::executeSwitchType() { + LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n"); + Type value = read(); + auto cases = read().getAsValueRange(); + handleSwitch(value, cases); +} + void ByteCodeExecutor::execute( PatternRewriter &rewriter, SmallVectorImpl *matches, @@ -850,383 +1225,105 @@ while (true) { OpCode opCode = static_cast(read()); switch (opCode) { - case ApplyConstraint: { - LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); - const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; - ArrayAttr constParams = read(); - SmallVector args; - readList(args); - LLVM_DEBUG({ - llvm::dbgs() << " * Arguments: "; - llvm::interleaveComma(args, llvm::dbgs()); - llvm::dbgs() << "\n * Parameters: " << constParams << "\n\n"; - }); - - // Invoke the constraint and jump to the proper destination. - selectJump(succeeded(constraintFn(args, constParams, rewriter))); + case ApplyConstraint: + executeApplyConstraint(rewriter); break; - } - case ApplyRewrite: { - LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); - const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; - ArrayAttr constParams = read(); - Operation *root = read(); - SmallVector args; - readList(args); - - LLVM_DEBUG({ - llvm::dbgs() << " * Root: " << *root << "\n" - << " * Arguments: "; - llvm::interleaveComma(args, llvm::dbgs()); - llvm::dbgs() << "\n * Parameters: " << constParams << "\n\n"; - }); - rewriteFn(root, args, constParams, rewriter); + case ApplyRewrite: + executeApplyRewrite(rewriter); break; - } - case AreEqual: { - LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); - const void *lhs = read(); - const void *rhs = read(); - - LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); - selectJump(lhs == rhs); + case AreEqual: + executeAreEqual(); break; - } - case Branch: { - LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n\n"); - curCodeIt = &code[read()]; + case Branch: + executeBranch(); break; - } - case CheckOperandCount: { - LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n"); - Operation *op = read(); - uint32_t expectedCount = read(); - - LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n" - << " * Expected: " << expectedCount << "\n\n"); - selectJump(op->getNumOperands() == expectedCount); + case CheckOperandCount: + executeCheckOperandCount(); break; - } - case CheckOperationName: { - LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n"); - Operation *op = read(); - OperationName expectedName = read(); - - LLVM_DEBUG(llvm::dbgs() - << " * Found: \"" << op->getName() << "\"\n" - << " * Expected: \"" << expectedName << "\"\n\n"); - selectJump(op->getName() == expectedName); + case CheckOperationName: + executeCheckOperationName(); break; - } - case CheckResultCount: { - LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n"); - Operation *op = read(); - uint32_t expectedCount = read(); - - LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n" - << " * Expected: " << expectedCount << "\n\n"); - selectJump(op->getNumResults() == expectedCount); + case CheckResultCount: + executeCheckResultCount(); break; - } - case CreateNative: { - LLVM_DEBUG(llvm::dbgs() << "Executing CreateNative:\n"); - const PDLCreateFunction &createFn = createFunctions[read()]; - ByteCodeField resultIndex = read(); - ArrayAttr constParams = read(); - SmallVector args; - readList(args); - - LLVM_DEBUG({ - llvm::dbgs() << " * Arguments: "; - llvm::interleaveComma(args, llvm::dbgs()); - llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; - }); - - PDLValue result = createFn(args, constParams, rewriter); - memory[resultIndex] = result.getAsOpaquePointer(); - - LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n\n"); + case CreateNative: + executeCreateNative(rewriter); break; - } - case CreateOperation: { - LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n"); - assert(mainRewriteLoc && "expected rewrite loc to be provided when " - "executing the rewriter bytecode"); - - unsigned memIndex = read(); - OperationState state(*mainRewriteLoc, read()); - readList(state.operands); - for (unsigned i = 0, e = read(); i != e; ++i) { - Identifier name = read(); - if (Attribute attr = read()) - state.addAttribute(name, attr); - } - - bool hasInferredTypes = false; - for (unsigned i = 0, e = read(); i != e; ++i) { - Type resultType = read(); - hasInferredTypes |= !resultType; - state.types.push_back(resultType); - } - - // Handle the case where the operation has inferred types. - if (hasInferredTypes) { - InferTypeOpInterface::Concept *concept = - state.name.getAbstractOperation() - ->getInterface(); - - // TODO: Handle failure. - SmallVector inferredTypes; - if (failed(concept->inferReturnTypes( - state.getContext(), state.location, state.operands, - state.attributes.getDictionary(state.getContext()), - state.regions, inferredTypes))) - return; - - for (unsigned i = 0, e = state.types.size(); i != e; ++i) - if (!state.types[i]) - state.types[i] = inferredTypes[i]; - } - Operation *resultOp = rewriter.createOperation(state); - memory[memIndex] = resultOp; - - LLVM_DEBUG({ - llvm::dbgs() << " * Attributes: " - << state.attributes.getDictionary(state.getContext()) - << "\n * Operands: "; - llvm::interleaveComma(state.operands, llvm::dbgs()); - llvm::dbgs() << "\n * Result Types: "; - llvm::interleaveComma(state.types, llvm::dbgs()); - llvm::dbgs() << "\n * Result: " << *resultOp << "\n\n"; - }); + case CreateOperation: + executeCreateOperation(rewriter, *mainRewriteLoc); break; - } - case EraseOp: { - LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n"); - Operation *op = read(); - - LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n\n"); - rewriter.eraseOp(op); + case EraseOp: + executeEraseOp(rewriter); break; - } - case Finalize: { + case Finalize: LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n"); return; - } - case GetAttribute: { - LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n"); - unsigned memIndex = read(); - Operation *op = read(); - Identifier attrName = read(); - Attribute attr = op->getAttr(attrName); - - LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" - << " * Attribute: " << attrName << "\n" - << " * Result: " << attr << "\n\n"); - memory[memIndex] = attr.getAsOpaquePointer(); + case GetAttribute: + executeGetAttribute(); break; - } - case GetAttributeType: { - LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n"); - unsigned memIndex = read(); - Attribute attr = read(); - - LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n" - << " * Result: " << attr.getType() << "\n\n"); - memory[memIndex] = attr.getType().getAsOpaquePointer(); + case GetAttributeType: + executeGetAttributeType(); break; - } - case GetDefiningOp: { - LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n"); - unsigned memIndex = read(); - Value value = read(); - Operation *op = value ? value.getDefiningOp() : nullptr; - - LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" - << " * Result: " << *op << "\n\n"); - memory[memIndex] = op; + case GetDefiningOp: + executeGetDefiningOp(); break; - } case GetOperand0: case GetOperand1: case GetOperand2: - case GetOperand3: - case GetOperandN: { - LLVM_DEBUG({ - llvm::dbgs() << "Executing GetOperand" - << (opCode == GetOperandN ? Twine("N") - : Twine(opCode - GetOperand0)) - << ":\n"; - }); - unsigned index = - opCode == GetOperandN ? read() : (opCode - GetOperand0); - Operation *op = read(); - unsigned memIndex = read(); - Value operand = - index < op->getNumOperands() ? op->getOperand(index) : Value(); - - LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" - << " * Index: " << index << "\n" - << " * Result: " << operand << "\n\n"); - memory[memIndex] = operand.getAsOpaquePointer(); + case GetOperand3: { + unsigned index = opCode - GetOperand0; + LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n"); + executeGetOperand(opCode - GetOperand0); break; } + case GetOperandN: + LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n"); + executeGetOperand(read()); + break; case GetResult0: case GetResult1: case GetResult2: - case GetResult3: - case GetResultN: { - LLVM_DEBUG({ - llvm::dbgs() << "Executing GetResult" - << (opCode == GetResultN ? Twine("N") - : Twine(opCode - GetResult0)) - << ":\n"; - }); - unsigned index = - opCode == GetResultN ? read() : (opCode - GetResult0); - Operation *op = read(); - unsigned memIndex = read(); - OpResult result = - index < op->getNumResults() ? op->getResult(index) : OpResult(); - - LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" - << " * Index: " << index << "\n" - << " * Result: " << result << "\n\n"); - memory[memIndex] = result.getAsOpaquePointer(); + case GetResult3: { + unsigned index = opCode - GetResult0; + LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n"); + executeGetResult(opCode - GetResult0); break; } - case GetValueType: { - LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n"); - unsigned memIndex = read(); - Value value = read(); - - LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" - << " * Result: " << value.getType() << "\n\n"); - memory[memIndex] = value.getType().getAsOpaquePointer(); + case GetResultN: + LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n"); + executeGetResult(read()); break; - } - case IsNotNull: { - LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n"); - const void *value = read(); - - LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n\n"); - selectJump(value != nullptr); + case GetValueType: + executeGetValueType(); break; - } - case RecordMatch: { - LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n"); + case IsNotNull: + executeIsNotNull(); + break; + case RecordMatch: assert(matches && "expected matches to be provided when executing the matcher"); - unsigned patternIndex = read(); - PatternBenefit benefit = currentPatternBenefits[patternIndex]; - const ByteCodeField *dest = &code[read()]; - - // If the benefit of the pattern is impossible, skip the processing of the - // rest of the pattern. - if (benefit.isImpossibleToMatch()) { - LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n\n"); - curCodeIt = dest; - break; - } - - // Create a fused location containing the locations of each of the - // operations used in the match. This will be used as the location for - // created operations during the rewrite that don't already have an - // explicit location set. - unsigned numMatchLocs = read(); - SmallVector matchLocs; - matchLocs.reserve(numMatchLocs); - for (unsigned i = 0; i != numMatchLocs; ++i) - matchLocs.push_back(read()->getLoc()); - Location matchLoc = rewriter.getFusedLoc(matchLocs); - - LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n" - << " * Location: " << matchLoc << "\n\n"); - matches->emplace_back(matchLoc, patterns[patternIndex], benefit); - readList(matches->back().values); - curCodeIt = dest; + executeRecordMatch(rewriter, *matches); break; - } - case ReplaceOp: { - LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n"); - Operation *op = read(); - SmallVector args; - readList(args); - - LLVM_DEBUG({ - llvm::dbgs() << " * Operation: " << *op << "\n" - << " * Values: "; - llvm::interleaveComma(args, llvm::dbgs()); - llvm::dbgs() << "\n\n"; - }); - rewriter.replaceOp(op, args); + case ReplaceOp: + executeReplaceOp(rewriter); break; - } - case SwitchAttribute: { - LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n"); - Attribute value = read(); - ArrayAttr cases = read(); - handleSwitch(value, cases); + case SwitchAttribute: + executeSwitchAttribute(); break; - } - case SwitchOperandCount: { - LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n"); - Operation *op = read(); - auto cases = read().getValues(); - - LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); - handleSwitch(op->getNumOperands(), cases); + case SwitchOperandCount: + executeSwitchOperandCount(); break; - } - case SwitchOperationName: { - LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n"); - OperationName value = read()->getName(); - size_t caseCount = read(); - - // The operation names are stored in-line, so to print them out for - // debugging purposes we need to read the array before executing the - // switch so that we can display all of the possible values. - LLVM_DEBUG({ - const ByteCodeField *prevCodeIt = curCodeIt; - llvm::dbgs() << " * Value: " << value << "\n" - << " * Cases: "; - llvm::interleaveComma( - llvm::map_range(llvm::seq(0, caseCount), - [&](size_t i) { return read(); }), - llvm::dbgs()); - llvm::dbgs() << "\n\n"; - curCodeIt = prevCodeIt; - }); - - // Try to find the switch value within any of the cases. - size_t jumpDest = 0; - for (size_t i = 0; i != caseCount; ++i) { - if (read() == value) { - curCodeIt += (caseCount - i - 1); - jumpDest = i + 1; - break; - } - } - selectJump(jumpDest); + case SwitchOperationName: + executeSwitchOperationName(); break; - } - case SwitchResultCount: { - LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n"); - Operation *op = read(); - auto cases = read().getValues(); - - LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); - handleSwitch(op->getNumResults(), cases); + case SwitchResultCount: + executeSwitchResultCount(); break; - } - case SwitchType: { - LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n"); - Type value = read(); - auto cases = read().getAsValueRange(); - handleSwitch(value, cases); + case SwitchType: + executeSwitchType(); break; } - } + LLVM_DEBUG(llvm::dbgs() << "\n"); } }