diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -114,6 +114,7 @@ #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/IR/AsmState.h" #include "mlir/IR/Operation.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" @@ -127,6 +128,7 @@ #include "llvm/ADT/SetOperations.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/FormatVariadic.h" #define DEBUG_TYPE "comprehensive-func-bufferize" @@ -137,6 +139,10 @@ #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") #define LDBG(X) LLVM_DEBUG(DBGS() << X) +// Forward declarations. +static std::string printOperationInfo(Operation *); +static std::string printValueInfo(Value); + //===----------------------------------------------------------------------===// // Generic helpers. //===----------------------------------------------------------------------===// @@ -173,13 +179,15 @@ /// Wrapper for better debugging. static void map(BlockAndValueMapping &bvm, ValueRange keys, ValueRange values) { assert(!keys.empty() && "Unexpected empty keys"); - LDBG("Map: " << keys.front() << " to " << values.front() << '\n'); + LDBG("\n\tMap: " << printValueInfo(keys.front()) + << "\n\tto: " << printValueInfo(values.front()) << '\n'); return bvm.map(keys, values); } /// Wrapper for better debugging. 
static void map(BlockAndValueMapping &bvm, Value key, Value value) { - LDBG("Map: " << key << " to " << value << '\n'); + LDBG("\n\tMap: " << printValueInfo(key) << "\n\tto: " << printValueInfo(value) + << '\n'); return bvm.map(key, value); } @@ -260,8 +268,9 @@ llvm::to_vector<4>(attr.getAsValueRange())) : SmallVector(op->getNumResults(), stringify(InPlaceSpec::None)); - LDBG("->set inPlace=" << stringify(inPlace) << ": " << *op - << " @idx=" << opResult.getResultNumber() << '\n'); + LDBG("->set inPlace=" << stringify(inPlace) << " <- #" + << opResult.getResultNumber() << ": " + << printOperationInfo(op) << "\n"); inPlaceVector[opResult.getResultNumber()] = stringify(inPlace); op->setAttr(kInPlaceResultsAttrName, OpBuilder(op).getStrArrayAttr(inPlaceVector)); @@ -338,6 +347,57 @@ return getInPlace(v.cast()); } +//===----------------------------------------------------------------------===// +// Printing helpers. +//===----------------------------------------------------------------------===// + +/// Helper method printing the bufferization information of a buffer / tensor. +static void printTensorOrBufferInfo(std::string prefix, Value value, + AsmState &state, llvm::raw_ostream &os) { + if (!value.getType().isa()) + return; + os << prefix; + value.printAsOperand(os, state); + os << " : " << value.getType(); + if (getInPlace(value) == InPlaceSpec::None) + return; + os << " [InPlace=" << stringify(getInPlace(value)) << "]"; +} + +/// Print the operation name and bufferization information.
+static std::string printOperationInfo(Operation *op) {
+  std::string result;
+  llvm::raw_string_ostream os(result);
+  Operation *scope = op->getParentOfType<FuncOp>();
+  AsmState state(scope ? scope : op); // AsmState needs a non-null root.
+  os << op->getName();
+  for (OpOperand &opOperand : op->getOpOperands()) {
+    std::string prefix =
+        llvm::formatv("\n\t-> #{0} ", opOperand.getOperandNumber());
+    printTensorOrBufferInfo(prefix, opOperand.get(), state, os);
+  }
+  for (OpResult opResult : op->getOpResults()) {
+    std::string prefix =
+        llvm::formatv("\n\t<- #{0} ", opResult.getResultNumber());
+    printTensorOrBufferInfo(prefix, opResult, state, os);
+  }
+  return os.str();
+}
+
+/// Print the bufferization information for the defining op or block argument.
+static std::string printValueInfo(Value value) {
+  if (Operation *op = value.getDefiningOp())
+    return printOperationInfo(op);
+  // Block argument: number SSA values from the enclosing function, if any.
+  Operation *scope = value.getParentRegion()->getParentOfType<FuncOp>();
+  std::string result;
+  llvm::raw_string_ostream os(result);
+  AsmState state(scope ? scope : value.getParentRegion()->getParentOp());
+  os << value;
+  printTensorOrBufferInfo("\n\t - ", value, state, os);
+  return os.str();
+}
+
 //===----------------------------------------------------------------------===//
 // Op-specific semantics helper to retrieve matching inplaceable result.
 // These should become proper interfaces interfaces when the time is right.
@@ -843,13 +903,14 @@ bool BufferizationAliasInfo::aliasesNonWriteableBuffer( OpOperand &operand) const { LDBG("----Start aliasesNonWriteableBuffer\n"); - LDBG("-------for operand #" << operand.getOperandNumber() << ": " - << *(operand.getOwner()) << '\n'); + LDBG("-------for -> #" << operand.getOperandNumber() << ": " + << printOperationInfo(operand.getOwner()) << '\n'); for (Value v : getAliasInfoRef(operand.get())) { - LDBG("-----------examine: " << v << '\n'); + LDBG("-----------examine: " << printValueInfo(v) << '\n'); if (auto bbArg = v.dyn_cast()) { if (getInPlace(bbArg) == InPlaceSpec::True) { - LDBG("-----------bbArg is writeable -> skip: " << bbArg << '\n'); + LDBG("-----------bbArg is writeable -> skip: " << printValueInfo(bbArg) + << '\n'); continue; } LDBG("-----------notWriteable\n"); @@ -871,12 +932,12 @@ /// to some buffer write. bool BufferizationAliasInfo::aliasesInPlaceWrite(Value value) const { LDBG("----Start aliasesInPlaceWrite\n"); - LDBG("-------for : " << value << '\n'); + LDBG("-------for : " << printValueInfo(value) << '\n'); for (Value v : getAliasInfoRef(value)) { for (auto &use : v.getUses()) { if (bufferizesToMemoryWrite(use, InPlaceSpec::True)) { LDBG("-----------wants to bufferize to inPlace write: " - << *use.getOwner() << '\n'); + << printOperationInfo(use.getOwner()) << '\n'); return true; } } @@ -923,7 +984,7 @@ Operation *opToBufferize = result.getDefiningOp(); Value root = (*maybeAliasingOperand)->get(); LDBG("----Start wouldCreateReadAfterWriteInterference\n"); - LDBG("--------rootValue: " << root << "\n"); + LDBG("--------rootValue: " << printValueInfo(root) << "\n"); // Collect: // 1. all the inplace write uses of some alias of `root`. 
@@ -961,8 +1022,9 @@ for (OpOperand *uRead : usesRead) { Operation *aliasingReadOp = uRead->getOwner(); - LDBG("----++++aliasRead #" << uRead->getOperandNumber() - << " in: " << *aliasingReadOp << '\n'); + LDBG("----++++aliasRead -> #" + << uRead->getOperandNumber() + << " in: " << printOperationInfo(aliasingReadOp) << '\n'); for (OpOperand *uWrite : usesWrite) { // Don't consider self-use of the same operand for interference. // Multiple different uses within the same op is fair game though. @@ -970,8 +1032,9 @@ continue; Operation *aliasingWriteOp = uWrite->getOwner(); - LDBG("---- aliasWrite #" << uWrite->getOperandNumber() - << " in: " << *aliasingWriteOp << '\n'); + LDBG("---- aliasWrite -> #" + << uWrite->getOperandNumber() + << " in: " << printOperationInfo(aliasingWriteOp) << '\n'); // If the candidate write is the one that produces the read value (in the // SSA def-use sense), this is not considered an interference. if (getInplaceableOpResult(*uWrite) == uRead->get()) @@ -983,10 +1046,12 @@ // At this point, aliasingWriteOp properly dominates aliasingReadOp or // there is no clear dominance and we need to be conservative. LDBG("---->found RaW interference\n"); - LDBG(" Interfering read (op #" << uRead->getOperandNumber() - << "): " << *aliasingReadOp << '\n'); - LDBG(" Interfering write (op #" << uWrite->getOperandNumber() - << "): " << *aliasingWriteOp << '\n'); + LDBG(" Interfering read -> #" << uRead->getOperandNumber() << ":\n" + << printOperationInfo(aliasingReadOp) + << '\n'); + LDBG(" Interfering write -> #" << uWrite->getOperandNumber() << ":\n" + << printOperationInfo(aliasingWriteOp) + << '\n'); LDBG("---->opportunity to clobber RaW interference\n"); if (isClobberedWriteBeforeRead(opToBufferize, *uRead, *uWrite, domInfo)) { LDBG("---->clobbered! 
-> skip\n"); @@ -1037,9 +1102,9 @@ os << "\n/========================== AliasInfo " "==========================\n"; for (auto it : aliasInfo) { - os << "|\n| -- source: " << it.getFirst() << '\n'; + os << "|\n| -- source: " << printValueInfo(it.getFirst()) << '\n'; for (auto v : it.getSecond()) - os << "| ---- target: " << v << '\n'; + os << "| ---- target: " << printValueInfo(v) << '\n'; } os << "|\n\\====================== End AliasInfo " "======================\n\n"; @@ -1049,12 +1114,12 @@ if (!it->isLeader()) continue; Value leader = it->getData(); - os << "|\n| -- leader: " << leader << '\n'; + os << "|\n| -- leader: " << printValueInfo(leader) << '\n'; for (auto mit = equivalentInfo.member_begin(it), meit = equivalentInfo.member_end(); mit != meit; ++mit) { Value v = static_cast(*mit); - os << "| ---- equivalent member: " << v << '\n'; + os << "| ---- equivalent member: " << printValueInfo(v) << '\n'; } } os << "|\n\\***************** End Equivalent Buffers *****************\n\n"; @@ -1136,7 +1201,8 @@ Operation *candidateOp = mit->v.getDefiningOp(); if (!candidateOp) continue; - LDBG("---->clobbering candidate: " << *candidateOp << '\n'); + LDBG("---->clobbering candidate: " << printOperationInfo(candidateOp) + << '\n'); if (domInfo.properlyDominates(aliasingWriteOp, candidateOp) && domInfo.properlyDominates(candidateOp, aliasingReadOp)) return true; @@ -2165,7 +2231,8 @@ BufferizationAliasInfo &aliasInfo, const DominanceInfo &domInfo) { LDBG('\n'); - LDBG("Inplace analysis for extract_slice: " << *extractSliceOp << '\n'); + LDBG("Inplace analysis for extract_slice: " + << printOperationInfo(extractSliceOp) << '\n'); // If `extractSliceOp` were to be bufferized inplace, it cannot end up // aliasing a write into a non-writeable buffer. 
@@ -2210,9 +2277,9 @@ int64_t resultNumber = result.getResultNumber(); (void)resultNumber; LDBG('\n'); - LDBG("Inplace analysis for result #" << resultNumber << " (operand #" - << operand.getOperandNumber() << ") in " - << result << '\n'); + LDBG("Inplace analysis for <- #" << resultNumber << " -> #" + << operand.getOperandNumber() << " in " + << printValueInfo(result) << '\n'); // `result` must bufferize to a writeable buffer to be a candidate. // This means the operand must not alias either: