diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -114,6 +114,7 @@
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/AsmState.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
@@ -338,6 +339,41 @@
   return getInPlace(v.cast());
 }
 
+//===----------------------------------------------------------------------===//
+// Printing helpers.
+//===----------------------------------------------------------------------===//
+
+/// Return string containing the essential bufferization information for the
+/// given value: its defining op with operands/results, or the value itself.
+static std::string printValueInfo(Value value) {
+  std::string buffer;
+  llvm::raw_string_ostream os(buffer);
+  // NOTE(review): template args of getParentOfType()/isa() are missing here.
+  AsmState state(value.getParentRegion()->getParentOfType());
+  // Print `prefix`, operand name, type and non-None in-place marker of `v`.
+  auto printTensorOrBufferInfo = [&](llvm::StringRef prefix, Value v) {
+    if (!v.getType().isa())
+      return;
+    os << prefix;
+    v.printAsOperand(os, state);
+    os << " : " << v.getType();
+    auto inPlace = getInPlace(v);
+    if (inPlace != InPlaceSpec::None)
+      os << " [InPlace=" << stringify(inPlace) << "]";
+  };
+  if (auto *op = value.getDefiningOp()) {
+    os << op->getName();
+    for (Value operand : op->getOperands())
+      printTensorOrBufferInfo("\n\t-> ", operand);
+    for (Value opResult : op->getOpResults())
+      printTensorOrBufferInfo("\n\t<- ", opResult);
+  } else {
+    os << value;
+    printTensorOrBufferInfo("\n\t - ", value);
+  }
+  return os.str(); // str() flushes `os` into `buffer` first.
+}
+
 //===----------------------------------------------------------------------===//
 // Op-specific semantics helper to retrieve matching inplaceable result.
 // These should become proper interfaces interfaces when the time is right.
@@ -1030,9 +1066,9 @@ os << "\n/========================== AliasInfo " "==========================\n"; for (auto it : aliasInfo) { - os << "|\n| -- source: " << it.getFirst() << '\n'; + os << "|\n| -- source: " << printValueInfo(it.getFirst()) << '\n'; for (auto v : it.getSecond()) - os << "| ---- target: " << v << '\n'; + os << "| ---- target: " << printValueInfo(v) << '\n'; } os << "|\n\\====================== End AliasInfo " "======================\n\n"; @@ -1042,12 +1078,12 @@ if (!it->isLeader()) continue; Value leader = it->getData(); - os << "|\n| -- leader: " << leader << '\n'; + os << "|\n| -- leader: " << printValueInfo(leader) << '\n'; for (auto mit = equivalentInfo.member_begin(it), meit = equivalentInfo.member_end(); mit != meit; ++mit) { Value v = static_cast(*mit); - os << "| ---- equivalent member: " << v << '\n'; + os << "| ---- equivalent member: " << printValueInfo(v) << '\n'; } } os << "|\n\\***************** End Equivalent Buffers *****************\n\n"; @@ -2655,7 +2691,7 @@ SmallVector orderedFuncOps; DenseMap> callerMap; auto res = getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap); - (void) res; + (void)res; assert(succeeded(res) && "unexpected getFuncOpsOrderedByCalls failure"); for (FuncOp funcOp : orderedFuncOps) {