diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
@@ -19,6 +19,13 @@
 namespace comprehensive_bufferize {
 namespace scf_ext {
 
+/// Equivalence analysis for scf.for. Raise an error if iter_args are not
+/// equivalent to their corresponding loop yield values.
+struct AssertDestinationPassingStyle : public PostAnalysisStep {
+  LogicalResult run(FuncOp funcOp, BufferizationState &state,
+                    SmallVector<Operation *> &newOps) override;
+};
+
 void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
 
 } // namespace scf_ext
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
@@ -37,7 +37,6 @@
     auto globalMemref = globalCreator.getGlobalFor(constantOp);
     Value memref = b.create<memref::GetGlobalOp>(
         constantOp.getLoc(), globalMemref.type(), globalMemref.getName());
-    state.aliasInfo.insertNewBufferEquivalence(memref, constantOp.getResult());
     state.mapBuffer(constantOp, memref);
 
     return success();
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -141,22 +141,7 @@
 
 /// Return `true` if a value was marked as in-place bufferized.
bool BufferizationAliasInfo::isInPlace(OpResult opResult) const { - bool inplace = inplaceBufferized.contains(opResult); -#ifndef NDEBUG - if (inplace) { - auto bufferizableOp = - dyn_cast(opResult.getDefiningOp()); - assert(bufferizableOp && - "expected that in-place bufferized op is bufferizable"); - SmallVector operands = - bufferizableOp.getAliasingOpOperand(opResult); - for (OpOperand *operand : operands) - assert(areAliasingBufferizedValues(operand->get(), opResult) && - "expected that in-place bufferized OpResult aliases with " - "aliasing OpOperand"); - } -#endif // NDEBUG - return inplace; + return inplaceBufferized.contains(opResult); } /// Set the inPlace bufferization spec to true. @@ -593,7 +578,6 @@ Value casted = allocated.getValue(); if (memRefType && memRefType != allocMemRefType) { casted = b.create(loc, memRefType, allocated.getValue()); - aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue()); } // 2. Create memory deallocation. diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp @@ -253,8 +253,6 @@ return failure(); // Insert mapping and aliasing info. - state.aliasInfo.createAliasInfoEntry(resultBuffer); - state.aliasInfo.insertNewBufferEquivalence(opResult, resultBuffer); state.mapBuffer(opResult, resultBuffer); // Insert new operand and bbArg. @@ -263,9 +261,6 @@ body->insertArgument(nextOutputBBArgIndex, resultBuffer.getType()); BlockArgument oldTensorBBArg = body->getArgument(oldOutputBBArgIndex); // Insert mapping and aliasing info. 
-    state.aliasInfo.createAliasInfoEntry(newBufferBBArg);
-    state.aliasInfo.insertNewBufferEquivalence(oldTensorBBArg,
-                                               newBufferBBArg);
     state.mapBuffer(oldTensorBBArg, newBufferBBArg);
 
     // Set operand of `linalg.yield` to the bbArg so it just canonicalizes
@@ -303,9 +298,6 @@
     BlockArgument oldTensorBBArg = body->getArgument(oldInputBBArgIndex);
 
     // Insert mapping and aliasing info.
-    state.aliasInfo.createAliasInfoEntry(newBufferBBArg);
-    state.aliasInfo.insertNewBufferEquivalence(oldTensorBBArg,
-                                               newBufferBBArg);
     state.mapBuffer(oldTensorBBArg, newBufferBBArg);
 
     // Increment indices.
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -223,7 +223,6 @@
                                         BufferizationState &state) {
   LLVM_DEBUG(DBGS() << "Begin bufferizeFuncOpBoundary:\n" << funcOp << "\n");
   ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
-  BufferizationAliasInfo &aliasInfo = state.aliasInfo;
 
   // If nothing to do then we are done.
   if (!llvm::any_of(funcOp.getType().getInputs(), isaTensor) &&
@@ -321,15 +320,12 @@
         auto castOp = b.create<memref::CastOp>(
             funcOp.getLoc(), toMemrefOp.memref().getType(), memref);
         toMemrefOp.memref().replaceAllUsesWith(castOp);
-        aliasInfo.insertNewBufferEquivalence(castOp.dest(),
-                                             toMemrefOp.memref());
       }
     }
 
     // Replace all remaining uses by a to_tensor.
     if (!bbArg.use_empty()) {
       auto toTensorOp =
           b.create<bufferization::ToTensorOp>(funcOp.getLoc(), memref);
-      aliasInfo.insertNewBufferEquivalence(toTensorOp, bbArg);
       bbArg.replaceAllUsesWith(toTensorOp);
     }
     frontBlock.eraseArgument(0);
@@ -562,7 +558,6 @@
       Value buffer = state.lookupBuffer(callOp->getOperand(idx));
 
       // Add CallOp operand/result equivalence: this is interprocedural
       // info.
-      state.aliasInfo.insertNewBufferEquivalence(oldRes, buffer);
       state.mapBuffer(oldRes, buffer);
 
       // Add a ToTensorOp to kill all uses of the CallOp return.
       // Replace all uses of the CallOp results so we can erase the CallOp.
@@ -572,7 +567,6 @@
           b.create<bufferization::ToTensorOp>(callOp.getLoc(), buffer);
       oldRes.replaceAllUsesWith(toTensorOp);
 
       // Add new op equivalence info.
-      state.aliasInfo.insertNewBufferEquivalence(toTensorOp, buffer);
       state.mapBuffer(toTensorOp, buffer);
 
       continue;
     }
@@ -615,7 +609,6 @@
       Value castBuffer =
           b.create<memref::CastOp>(callOp.getLoc(), memRefType, buffer);
       // Add new op equivalence info.
-      state.aliasInfo.insertNewBufferEquivalence(castBuffer, buffer);
       state.mapBuffer(tensorOperand, castBuffer);
       buffer = castBuffer;
     }
@@ -663,7 +656,6 @@
     Value returnTensor = b.create<bufferization::ToTensorOp>(
         returnOp.getLoc(), v);
     operand.set(returnTensor);
-    state.aliasInfo.insertNewBufferEquivalence(returnTensor, v);
     state.mapBuffer(returnTensor, v);
   }
   return success();
@@ -690,7 +682,6 @@
             : getContiguousOrUnrankedMemRefType(tensorType);
     Value bufferCast =
         b.create<bufferization::ToMemrefOp>(funcOp.getLoc(), memRefType, bbArg);
-    state.aliasInfo.insertNewBufferEquivalence(bufferCast, bbArg);
     state.mapBuffer(bbArg, bufferCast);
   }
 
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -147,7 +147,6 @@
       if (!resultBuffer)
         return failure();
 
-      state.aliasInfo.createAliasInfoEntry(resultBuffer);
       state.mapBuffer(opResult, resultBuffer);
     }
 
@@ -237,8 +236,6 @@
       OpOperand &opOperand = forOp.getOpOperandForResult(opResult);
       BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand);
-      state.aliasInfo.createAliasInfoEntry(resultBuffer);
-      state.aliasInfo.insertNewBufferEquivalence(bbArg, resultBuffer);
       state.mapBuffer(bbArg, resultBuffer);
       state.mapBuffer(opResult, resultBuffer);
     }
 
@@ -257,15 +254,6 @@
OpOperand &forOperand = forOp.getOpOperandForResult( forOp->getResult(operand.getOperandNumber())); auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand); - if (!state.aliasInfo.areEquivalentBufferizedValues(operand.get(), - bbArg)) { - // TODO: this could get resolved with copies but it can also turn into - // swaps so we need to be careful about order of copies. - return yieldOp->emitError() - << "Yield operand #" << operand.getOperandNumber() - << " does not bufferize to an equivalent buffer to the matching" - << " enclosing scf::for operand"; - } // Buffers are equivalent so the work is already done and we just yield // the bbArg so that it later canonicalizes away. @@ -275,6 +263,41 @@ } }; +LogicalResult mlir::linalg::comprehensive_bufferize::scf_ext:: + AssertDestinationPassingStyle::run(FuncOp funcOp, BufferizationState &state, + SmallVector &newOps) { + LogicalResult status = success(); + funcOp->walk([&](scf::YieldOp yieldOp) { + auto forOp = dyn_cast(yieldOp->getParentOp()); + if (!forOp) + return WalkResult::advance(); + + for (OpOperand &operand : yieldOp->getOpOperands()) { + auto tensorType = operand.get().getType().dyn_cast(); + if (!tensorType) + continue; + + OpOperand &forOperand = forOp.getOpOperandForResult( + forOp->getResult(operand.getOperandNumber())); + auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand); + if (!state.aliasInfo.areEquivalentBufferizedValues(operand.get(), + bbArg)) { + // TODO: this could get resolved with copies but it can also turn into + // swaps so we need to be careful about order of copies. 
+        status =
+            yieldOp->emitError()
+            << "Yield operand #" << operand.getOperandNumber()
+            << " does not bufferize to an equivalent buffer to the matching"
+            << " enclosing scf::for operand";
+        return WalkResult::interrupt();
+      }
+    }
+
+    return WalkResult::advance();
+  });
+  return status;
+}
+
 struct YieldOpInterface
     : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                     scf::YieldOp> {
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -80,7 +80,6 @@
         castOp.getResult().getType(), layout, memorySpace);
     Value res =
         b.create<memref::CastOp>(castOp.getLoc(), memRefType, resultBuffer);
-    state.aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
     state.mapBuffer(castOp.getResult(), res);
     return success();
   }
@@ -233,7 +232,6 @@
     b.create<memref::StoreOp>(loc, insertOp.scalar(), destMemref,
                               insertOp.indices());
     state.mapBuffer(insertOp, destMemref);
-    state.aliasInfo.insertNewBufferAlias(insertOp, destMemref);
     return success();
   }
 
@@ -421,8 +419,6 @@
     Value subView = b.create<memref::SubViewOp>(
         loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
         insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
-    // Insert new alias.
-    state.aliasInfo.insertNewBufferAlias(subView, dstMemref);
 
     // Copy tensor.
     Value srcMemref = state.lookupBuffer(insertSliceOp.source());
     state.options.allocationFns->memCpyFn(b, insertSliceOp.getLoc(),
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -96,6 +96,7 @@
   // TODO: Find a way to enable this step automatically when bufferizing tensor
   // dialect ops.
options.addPostAnalysisStep(); + options.addPostAnalysisStep(); ModuleOp moduleOp = getOperation(); applyEnablingTransformations(moduleOp); diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir @@ -1113,7 +1113,7 @@ // Read from %t1 via alias %e. %v2 = vector.transfer_read %e[%s], %cst : tensor, vector<5xf32> - scf.yield %e, %v2 : tensor, vector<5xf32> + scf.yield %t2, %v2 : tensor, vector<5xf32> } // CHECK: __inplace_results_attr__ = ["true", "false"] @@ -1154,14 +1154,10 @@ // This loop does not read from %t1. It only writes to it. // CHECK: scf.for %r, %v3 = scf.for %i = %c0 to %s step %c1 iter_args(%t2 = %t1, %v0 = %v) -> (tensor, vector<5xf32>) { - // CHECK: tensor.extract_slice - // CHECK-SAME: __inplace_results_attr__ = ["true"] - %e = tensor.extract_slice %t2[%s][%s][1] : tensor to tensor - - // Write to %t1 via alias. (Overwrite %t3.) + // Write to %t1 via %t2. (Overwrite %t3.) // CHECK: linalg.generic // CHECK-SAME: __inplace_results_attr__ = ["true"] - %o2 = linalg.generic #trait outs (%e : tensor) { + %o2 = linalg.generic #trait outs (%t2 : tensor) { ^bb(%0: f32) : linalg.yield %cst : f32 } -> (tensor) @@ -1172,8 +1168,8 @@ } // Use %t3 in some way without reading it, so that it does not get DCE'd. - // CHECK: linalg.generic - // CHECK-SAME: __inplace_results_attr__ = ["true"] + // CHECK: linalg.generic + // CHECK-SAME: __inplace_results_attr__ = ["true"] %o = linalg.generic #trait outs (%t3 : tensor) { ^bb(%0: f32) : linalg.yield %cst : f32