Differential D115491 Diff 393462 mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
Show First 20 Lines • Show All 125 Lines • ▼ Show 20 Lines | bool mustBufferizeInPlace(Operation *op, OpResult opResult) const { | ||||
// they are mostly ignored by the analysis once alias sets are set up. | // they are mostly ignored by the analysis once alias sets are set up. | ||||
return true; | return true; | ||||
} | } | ||||
LogicalResult bufferize(Operation *op, OpBuilder &b, | LogicalResult bufferize(Operation *op, OpBuilder &b, | ||||
BufferizationState &state) const { | BufferizationState &state) const { | ||||
auto ifOp = cast<scf::IfOp>(op); | auto ifOp = cast<scf::IfOp>(op); | ||||
// Bufferize then/else blocks. | // Use IRRewriter instead of OpBuilder because it has additional helper | ||||
if (failed(comprehensive_bufferize::bufferize(ifOp.thenBlock(), state))) | // functions. | ||||
return failure(); | IRRewriter rewriter(op->getContext()); | ||||
if (failed(comprehensive_bufferize::bufferize(ifOp.elseBlock(), state))) | rewriter.setInsertionPoint(ifOp); | ||||
return failure(); | |||||
for (OpResult opResult : ifOp->getResults()) { | // Compute new types of the bufferized scf.if op. | ||||
if (!opResult.getType().isa<TensorType>()) | SmallVector<Type> newTypes; | ||||
continue; | for (Type returnType : ifOp->getResultTypes()) { | ||||
// TODO: Atm we bail on unranked TensorType because we don't know how to | if (returnType.isa<TensorType>()) { | ||||
// alloc an UnrankedMemRefType + its underlying ranked MemRefType. | assert(returnType.isa<RankedTensorType>() && | ||||
assert(opResult.getType().isa<RankedTensorType>() && | |||||
"unsupported unranked tensor"); | "unsupported unranked tensor"); | ||||
newTypes.push_back( | |||||
getDynamicMemRefType(returnType.cast<RankedTensorType>())); | |||||
} else { | |||||
newTypes.push_back(returnType); | |||||
} | |||||
} | |||||
Value resultBuffer = state.getResultBuffer(opResult); | // Create new op. | ||||
if (!resultBuffer) | auto newIfOp = | ||||
return failure(); | rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.condition(), | ||||
/*withElseRegion=*/true); | |||||
state.mapBuffer(opResult, resultBuffer); | // Remove terminators. | ||||
if (!newIfOp.thenBlock()->empty()) { | |||||
rewriter.eraseOp(newIfOp.thenBlock()->getTerminator()); | |||||
rewriter.eraseOp(newIfOp.elseBlock()->getTerminator()); | |||||
} | } | ||||
// Move over then/else blocks. | |||||
rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock()); | |||||
rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock()); | |||||
// Update scf.yield of new then-block. | |||||
auto thenYieldOp = cast<scf::YieldOp>(newIfOp.thenBlock()->getTerminator()); | |||||
rewriter.setInsertionPoint(thenYieldOp); | |||||
SmallVector<Value> thenYieldValues; | |||||
for (OpOperand &operand : thenYieldOp->getOpOperands()) { | |||||
if (operand.get().getType().isa<TensorType>()) { | |||||
Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>( | |||||
operand.get().getLoc(), newTypes[operand.getOperandNumber()], | |||||
operand.get()); | |||||
operand.set(toMemrefOp); | |||||
} | |||||
} | |||||
// Update scf.yield of new else-block. | |||||
auto elseYieldOp = cast<scf::YieldOp>(newIfOp.elseBlock()->getTerminator()); | |||||
rewriter.setInsertionPoint(elseYieldOp); | |||||
SmallVector<Value> elseYieldValues; | |||||
for (OpOperand &operand : elseYieldOp->getOpOperands()) { | |||||
if (operand.get().getType().isa<TensorType>()) { | |||||
Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>( | |||||
operand.get().getLoc(), newTypes[operand.getOperandNumber()], | |||||
operand.get()); | |||||
operand.set(toMemrefOp); | |||||
} | |||||
} | |||||
// Replace op results. | |||||
state.replaceOp(op, newIfOp->getResults()); | |||||
// Bufferize then/else blocks. | |||||
if (failed(comprehensive_bufferize::bufferize(newIfOp.thenBlock(), state))) | |||||
return failure(); | |||||
if (failed(comprehensive_bufferize::bufferize(newIfOp.elseBlock(), state))) | |||||
return failure(); | |||||
return success(); | return success(); | ||||
} | } | ||||
BufferRelation bufferRelation(Operation *op, OpResult opResult, | BufferRelation bufferRelation(Operation *op, OpResult opResult, | ||||
const BufferizationAliasInfo &aliasInfo) const { | const BufferizationAliasInfo &aliasInfo) const { | ||||
// IfOp results are equivalent to their corresponding yield values if both | // IfOp results are equivalent to their corresponding yield values if both | ||||
// yield values are equivalent to each other. | // yield values are equivalent to each other. | ||||
auto bufferizableOp = cast<BufferizableOpInterface>(op); | auto bufferizableOp = cast<BufferizableOpInterface>(op); | ||||
▲ Show 20 Lines • Show All 125 Lines • ▼ Show 20 Lines | LogicalResult bufferize(Operation *op, OpBuilder & /*b*/, | ||||
// Bufferize loop body. | // Bufferize loop body. | ||||
if (failed(comprehensive_bufferize::bufferize(loopBody, state))) | if (failed(comprehensive_bufferize::bufferize(loopBody, state))) | ||||
return failure(); | return failure(); | ||||
return success(); | return success(); | ||||
} | } | ||||
}; | }; | ||||
// TODO: Evolve toward matching ReturnLike ops. Check for aliasing values that | |||||
// do not bufferize inplace. (Requires a few more changes for ConstantOp, | |||||
// InitTensorOp, CallOp.) | |||||
LogicalResult mlir::linalg::comprehensive_bufferize::scf_ext:: | LogicalResult mlir::linalg::comprehensive_bufferize::scf_ext:: | ||||
AssertDestinationPassingStyle::run(Operation *op, BufferizationState &state, | AssertDestinationPassingStyle::run(Operation *op, BufferizationState &state, | ||||
BufferizationAliasInfo &aliasInfo, | BufferizationAliasInfo &aliasInfo, | ||||
SmallVector<Operation *> &newOps) { | SmallVector<Operation *> &newOps) { | ||||
LogicalResult status = success(); | LogicalResult status = success(); | ||||
op->walk([&](scf::YieldOp yieldOp) { | op->walk([&](scf::YieldOp yieldOp) { | ||||
auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp()); | if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) { | ||||
if (!forOp) | |||||
return WalkResult::advance(); | |||||
for (OpOperand &operand : yieldOp->getOpOperands()) { | for (OpOperand &operand : yieldOp->getOpOperands()) { | ||||
auto tensorType = operand.get().getType().dyn_cast<TensorType>(); | auto tensorType = operand.get().getType().dyn_cast<TensorType>(); | ||||
if (!tensorType) | if (!tensorType) | ||||
continue; | continue; | ||||
OpOperand &forOperand = forOp.getOpOperandForResult( | OpOperand &forOperand = forOp.getOpOperandForResult( | ||||
forOp->getResult(operand.getOperandNumber())); | forOp->getResult(operand.getOperandNumber())); | ||||
auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand); | auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand); | ||||
if (!aliasInfo.areEquivalentBufferizedValues(operand.get(), bbArg)) { | if (!aliasInfo.areEquivalentBufferizedValues(operand.get(), bbArg)) { | ||||
// TODO: this could get resolved with copies but it can also turn into | // TODO: this could get resolved with copies but it can also turn into | ||||
// swaps so we need to be careful about order of copies. | // swaps so we need to be careful about order of copies. | ||||
status = | status = | ||||
yieldOp->emitError() | yieldOp->emitError() | ||||
<< "Yield operand #" << operand.getOperandNumber() | << "Yield operand #" << operand.getOperandNumber() | ||||
<< " does not bufferize to an equivalent buffer to the matching" | << " does not bufferize to an equivalent buffer to the matching" | ||||
<< " enclosing scf::for operand"; | << " enclosing scf::for operand"; | ||||
return WalkResult::interrupt(); | return WalkResult::interrupt(); | ||||
} | } | ||||
} | } | ||||
} | |||||
if (auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp())) { | |||||
// IfOps are in destination passing style if all yielded tensors are | |||||
// a value or equivalent to a value that is defined outside of the IfOp. | |||||
for (OpOperand &operand : yieldOp->getOpOperands()) { | |||||
auto tensorType = operand.get().getType().dyn_cast<TensorType>(); | |||||
if (!tensorType) | |||||
continue; | |||||
bool foundOutsideEquivalent = false; | |||||
aliasInfo.applyOnEquivalenceClass(operand.get(), [&](Value value) { | |||||
Operation *valueOp = value.getDefiningOp(); | |||||
if (value.isa<BlockArgument>()) | |||||
valueOp = value.cast<BlockArgument>().getOwner()->getParentOp(); | |||||
bool inThenBlock = ifOp.thenBlock()->findAncestorOpInBlock(*valueOp); | |||||
bool inElseBlock = ifOp.elseBlock()->findAncestorOpInBlock(*valueOp); | |||||
if (!inThenBlock && !inElseBlock) | |||||
foundOutsideEquivalent = true; | |||||
}); | |||||
if (!foundOutsideEquivalent) { | |||||
status = yieldOp->emitError() | |||||
<< "Yield operand #" << operand.getOperandNumber() | |||||
<< " does not bufferize to a buffer that is equivalent to a" | |||||
<< " buffer defined outside of the scf::if op"; | |||||
return WalkResult::interrupt(); | |||||
} | |||||
} | |||||
} | |||||
return WalkResult::advance(); | return WalkResult::advance(); | ||||
}); | }); | ||||
return status; | return status; | ||||
} | } | ||||
struct YieldOpInterface | struct YieldOpInterface | ||||
: public BufferizableOpInterface::ExternalModel<YieldOpInterface, | : public BufferizableOpInterface::ExternalModel<YieldOpInterface, | ||||
Show All 38 Lines |