diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -140,6 +140,10 @@ /// checking the results of the analysis) and post analysis steps. bool testAnalysisOnly = false; + /// If set to `true`, the IR is annotated with details about RaW conflicts. + /// For debugging only. Should be used together with `testAnalysisOnly`. + bool printConflicts = false; + /// Registered post analysis steps. PostAnalysisStepList postAnalysisSteps; diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -39,6 +39,9 @@ Option<"testAnalysisOnly", "test-analysis-only", "bool", /*default=*/"false", "Only runs inplaceability analysis (for testing purposes only)">, + Option<"printConflicts", "print-conflicts", "bool", + /*default=*/"false", + "Annotates IR with RaW conflicts. Requires test-analysis-only.">, Option<"allowReturnMemref", "allow-return-memref", "bool", /*default=*/"false", "Allows the return of memrefs (for testing purposes only)">, diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -117,25 +117,12 @@ #include "mlir/IR/TypeUtilities.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SetVector.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/FormatVariadic.h" - -#define DEBUG_TYPE "comprehensive-module-bufferize" using namespace mlir; using namespace linalg; using namespace tensor; using namespace comprehensive_bufferize; -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X) - -// Forward declarations. -#ifndef NDEBUG -static std::string printOperationInfo(Operation *, bool prefix = true); -static std::string printValueInfo(Value, bool prefix = true); -#endif - static bool isaTensor(Type t) { return t.isa(); } //===----------------------------------------------------------------------===// @@ -164,64 +151,11 @@ attr ? SmallVector( llvm::to_vector<4>(attr.getAsValueRange())) : SmallVector(op->getNumResults(), "false"); - LDBG("->set inPlace=" << inPlace << " <- #" << opResult.getResultNumber() - << ": " << printOperationInfo(op) << "\n"); inPlaceVector[opResult.getResultNumber()] = inPlace ? "true" : "false"; op->setAttr(kInPlaceResultsAttrName, OpBuilder(op).getStrArrayAttr(inPlaceVector)); } -//===----------------------------------------------------------------------===// -// Printing helpers. -//===----------------------------------------------------------------------===// - -#ifndef NDEBUG -/// Helper method printing the bufferization information of a buffer / tensor. -static void printTensorOrBufferInfo(std::string prefix, Value value, - AsmState &state, llvm::raw_ostream &os) { - if (!value.getType().isa()) - return; - os << prefix; - value.printAsOperand(os, state); - os << " : " << value.getType(); -} - -/// Print the operation name and bufferization information. -static std::string printOperationInfo(Operation *op, bool prefix) { - std::string result; - llvm::raw_string_ostream os(result); - AsmState state(op->getParentOfType()); - StringRef tab = prefix ? "\n[" DEBUG_TYPE "]\t" : ""; - os << tab << op->getName(); - SmallVector shapedOperands; - for (OpOperand &opOperand : op->getOpOperands()) { - std::string prefix = - llvm::formatv("{0} -> #{1} ", tab, opOperand.getOperandNumber()); - printTensorOrBufferInfo(prefix, opOperand.get(), state, os); - } - for (OpResult opResult : op->getOpResults()) { - std::string prefix = - llvm::formatv("{0} <- #{1} ", tab, opResult.getResultNumber()); - printTensorOrBufferInfo(prefix, opResult, state, os); - } - return result; -} - -/// Print the bufferization information for the defining op or block argument. -static std::string printValueInfo(Value value, bool prefix) { - auto *op = value.getDefiningOp(); - if (op) - return printOperationInfo(op, prefix); - // Print the block argument bufferization information. - std::string result; - llvm::raw_string_ostream os(result); - AsmState state(value.getParentRegion()->getParentOfType()); - os << value; - printTensorOrBufferInfo("\n\t - ", value, state, os); - return result; -} -#endif - //===----------------------------------------------------------------------===// // Bufferization-specific alias analysis. //===----------------------------------------------------------------------===// @@ -251,7 +185,6 @@ static bool aliasesNonWritableBuffer(Value value, const BufferizationAliasInfo &aliasInfo, BufferizationState &state) { - LDBG("WRITABILITY ANALYSIS FOR " << printValueInfo(value) << "\n"); bool foundNonWritableBuffer = false; aliasInfo.applyOnAliases(value, [&](Value v) { // Query BufferizableOpInterface to see if the OpResult is writable. @@ -270,11 +203,6 @@ foundNonWritableBuffer = true; }); - if (foundNonWritableBuffer) - LDBG("--> NON WRITABLE\n"); - else - LDBG("--> WRITABLE\n"); - return foundNonWritableBuffer; } @@ -282,23 +210,15 @@ /// to some buffer write. static bool aliasesInPlaceWrite(Value value, const BufferizationAliasInfo &aliasInfo) { - LDBG("----Start aliasesInPlaceWrite\n"); - LDBG("-------for : " << printValueInfo(value) << '\n'); bool foundInplaceWrite = false; aliasInfo.applyOnAliases(value, [&](Value v) { for (auto &use : v.getUses()) { if (isInplaceMemoryWrite(use, aliasInfo)) { - LDBG("-----------wants to bufferize to inPlace write: " - << printOperationInfo(use.getOwner()) << '\n'); foundInplaceWrite = true; return; } } }); - - if (!foundInplaceWrite) - LDBG("----------->does not alias an inplace write\n"); - return foundInplaceWrite; } @@ -317,6 +237,39 @@ return false; } +/// Annotate IR with details about the detected RaW conflict. +static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite, + Value lastWrite) { + static uint64_t counter = 0; + Operation *readingOp = uRead->getOwner(); + Operation *conflictingWritingOp = uConflictingWrite->getOwner(); + + OpBuilder b(conflictingWritingOp->getContext()); + std::string id = "C_" + std::to_string(counter++); + + std::string conflictingWriteAttr = + id + + "[CONFL-WRITE: " + std::to_string(uConflictingWrite->getOperandNumber()) + + "]"; + conflictingWritingOp->setAttr(conflictingWriteAttr, b.getUnitAttr()); + + std::string readAttr = + id + "[READ: " + std::to_string(uRead->getOperandNumber()) + "]"; + readingOp->setAttr(readAttr, b.getUnitAttr()); + + if (auto opResult = lastWrite.dyn_cast()) { + std::string lastWriteAttr = id + "[LAST-WRITE: result " + + std::to_string(opResult.getResultNumber()) + + "]"; + opResult.getDefiningOp()->setAttr(lastWriteAttr, b.getUnitAttr()); + } else { + auto bbArg = lastWrite.cast(); + std::string lastWriteAttr = + id + "[LAST-WRITE: bbArg " + std::to_string(bbArg.getArgNumber()) + "]"; + bbArg.getOwner()->getParentOp()->setAttr(lastWriteAttr, b.getUnitAttr()); + } +} + /// Given sets of uses and writes, return true if there is a RaW conflict under /// the assumption that all given reads/writes alias the same buffer and that /// all given writes bufferize inplace. @@ -351,14 +304,6 @@ // met for uConflictingWrite to be an actual conflict. Operation *conflictingWritingOp = uConflictingWrite->getOwner(); - // Print some debug info. - LDBG("Found potential conflict:\n"); - LDBG("READ = #" << uRead->getOperandNumber() << " of " - << printOperationInfo(readingOp) << "\n"); - LDBG("CONFLICTING WRITE = #" - << uConflictingWrite->getOperandNumber() << " of " - << printOperationInfo(conflictingWritingOp) << "\n"); - // No conflict if the readingOp dominates conflictingWritingOp, i.e., the // write is not visible when reading. if (happensBefore(readingOp, conflictingWritingOp, domInfo)) @@ -387,8 +332,6 @@ if (insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp)) continue; - LDBG("WRITE = #" << printValueInfo(lastWrite) << "\n"); - // No conflict if the conflicting write happens before the last // write. if (Operation *writingOp = lastWrite.getDefiningOp()) { @@ -413,12 +356,14 @@ continue; // All requirements are met. Conflict found! - LDBG("CONFLICT CONFIRMED!\n\n"); + + if (options.printConflicts) + annotateConflict(uRead, uConflictingWrite, lastWrite); + return true; } } - LDBG("NOT A CONFLICT!\n\n"); return false; } @@ -530,7 +475,6 @@ if (!hasWrite) return false; - LDBG("->the corresponding buffer is not writeable\n"); return true; } @@ -548,13 +492,6 @@ "operand and result do not match"); #endif // NDEBUG - int64_t resultNumber = result.getResultNumber(); - (void)resultNumber; - LDBG('\n'); - LDBG("Inplace analysis for <- #" << resultNumber << " -> #" - << operand.getOperandNumber() << " in " - << printValueInfo(result) << '\n'); - bool foundInterference = wouldCreateWriteToNonWritableBuffer(operand, result, aliasInfo, state) || wouldCreateReadAfterWriteInterference(operand, result, domInfo, state, @@ -565,8 +502,6 @@ else aliasInfo.bufferizeInPlace(result, operand); - LDBG("Done inplace analysis for result #" << resultNumber << '\n'); - return success(); } diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -14,12 +14,6 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Operation.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/FormatVariadic.h" - -#define DEBUG_TYPE "comprehensive-module-bufferize" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X) using namespace mlir; using namespace linalg; @@ -181,7 +175,6 @@ auto it2 = bufferizedFunctionTypes.try_emplace( funcOp, getBufferizedFunctionType(funcOp.getContext(), argumentTypes, resultTypes)); - LDBG("FT: " << funcOp.getType() << " -> " << it2.first->second << "\n"); return it2.first->second; } @@ -227,7 +220,6 @@ /// future. static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp, BufferizationState &state) { - LLVM_DEBUG(DBGS() << "Begin bufferizeFuncOpBoundary:\n" << funcOp << "\n"); ModuleBufferizationState &moduleState = getModuleBufferizationState(state); // If nothing to do then we are done. @@ -261,7 +253,6 @@ funcOp, funcOp.getType().getInputs(), TypeRange{}, moduleState.bufferizedFunctionTypes); funcOp.setType(bufferizedFuncType); - LLVM_DEBUG(DBGS() << "End bufferizeFuncOpBoundary no fun body: " << funcOp); return success(); } @@ -341,8 +332,6 @@ // 4. Rewrite the FuncOp type to buffer form. funcOp.setType(bufferizedFuncType); - LLVM_DEBUG(DBGS() << "End bufferizeFuncOpBoundary:\n" << funcOp); - return success(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -89,6 +89,7 @@ options.allowUnknownOps = allowUnknownOps; options.analysisFuzzerSeed = analysisFuzzerSeed; options.testAnalysisOnly = testAnalysisOnly; + options.printConflicts = printConflicts; // Enable InitTensorOp elimination. options.addPostAnalysisStep<