diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -543,11 +543,11 @@ Optional overrideInPlace = None, Optional customCopyInsertionPoint = None); - /// Return the buffer type for a given OpOperand (tensor) after bufferization. + /// Return the buffer type for a given Value (tensor) after bufferization. /// /// Note: Op implementations should preferrably call `getBuffer()->getType()`. /// This function should only be used if `getBuffer` cannot be used. - BaseMemRefType getBufferType(OpOperand &opOperand) const; + BaseMemRefType getBufferType(Value value) const; /// Return a reference to the BufferizationOptions. const BufferizationOptions &getOptions() const { diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -333,13 +333,12 @@ return resultBuffer; } -/// Return the buffer type for a given OpOperand (tensor) after bufferization. -BaseMemRefType BufferizationState::getBufferType(OpOperand &opOperand) const { - Value tensor = opOperand.get(); - auto tensorType = tensor.getType().dyn_cast(); +/// Return the buffer type for a given Value (tensor) after bufferization. +BaseMemRefType BufferizationState::getBufferType(Value value) const { + auto tensorType = value.getType().dyn_cast(); assert(tensorType && "unexpected non-tensor type"); - if (auto toTensorOp = tensor.getDefiningOp()) + if (auto toTensorOp = value.getDefiningOp()) return toTensorOp.memref().getType().cast(); return getMemRefType(tensorType, getOptions()); diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -276,14 +276,14 @@ DenseSet getEquivalentBuffers(Block::BlockArgListType bbArgs, ValueRange yieldedValues, const AnalysisState &state) { + unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size()); DenseSet result; - int64_t counter = 0; - for (const auto &it : llvm::zip(bbArgs, yieldedValues)) { - if (!std::get<0>(it).getType().isa()) + for (unsigned int i = 0; i < minSize; ++i) { + if (!bbArgs[i].getType().isa() || + !yieldedValues[i].getType().isa()) continue; - if (state.areEquivalentBufferizedValues(std::get<0>(it), std::get<1>(it))) - result.insert(counter); - counter++; + if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i])) + result.insert(i); } return result; } @@ -486,8 +486,6 @@ // The new memref init_args of the loop. SmallVector initArgs = getBuffers(rewriter, forOp.getIterOpOperands(), state); - if (initArgs.size() != indices.size()) - return failure(); // Construct a new scf.for op with memref instead of tensor values. auto newForOp = rewriter.create( @@ -578,7 +576,16 @@ SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { auto whileOp = cast(op); - return {whileOp->getResult(opOperand.getOperandNumber())}; + unsigned int idx = opOperand.getOperandNumber(); + + // The OpResults and OpOperands may not match. They may not even have the + // same type. The number of OpResults and OpOperands can also differ. + if (idx >= op->getNumResults() || + opOperand.get().getType() != op->getResult(idx).getType()) + return {}; + + // The only aliasing OpResult may be the one at the same index. + return {whileOp->getResult(idx)}; } BufferRelation bufferRelation(Operation *op, OpResult opResult, @@ -589,6 +596,13 @@ unsigned int resultNumber = opResult.getResultNumber(); auto whileOp = cast(op); + // The "before" region bbArgs and the OpResults may not match. + if (resultNumber >= whileOp.getBeforeArguments().size()) + return BufferRelation::None; + if (opResult.getType() != + whileOp.getBeforeArguments()[resultNumber].getType()) + return BufferRelation::None; + auto conditionOp = whileOp.getConditionOp(); BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber]; Value conditionOperand = conditionOp.getArgs()[resultNumber]; @@ -627,9 +641,12 @@ "regions with multiple blocks not supported"); Block *afterBody = &whileOp.getAfter().front(); - // Indices of all iter_args that have tensor type. These are the ones that - // are bufferized. - DenseSet indices = getTensorIndices(whileOp.getInits()); + // Indices of all bbArgs that have tensor type. These are the ones that + // are bufferized. The "before" and "after" regions may have different args. + DenseSet indicesBefore = getTensorIndices(whileOp.getInits()); + DenseSet indicesAfter = + getTensorIndices(whileOp.getAfterArguments()); + // For every yielded value, is the value equivalent to its corresponding // bbArg? DenseSet equivalentYieldsBefore = getEquivalentBuffers( @@ -642,51 +659,64 @@ // The new memref init_args of the loop. SmallVector initArgs = getBuffers(rewriter, whileOp->getOpOperands(), state); - if (initArgs.size() != indices.size()) - return failure(); + + // The result types of a WhileOp are the same as the "after" bbArg types. + SmallVector argsTypesAfter = llvm::to_vector( + llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) { + return state.getBufferType(bbArg).cast(); + })); // Construct a new scf.while op with memref instead of tensor values. - ValueRange argsRange(initArgs); - TypeRange argsTypes(argsRange); - auto newWhileOp = - rewriter.create(whileOp.getLoc(), argsTypes, initArgs); + ValueRange argsRangeBefore(initArgs); + TypeRange argsTypesBefore(argsRangeBefore); + auto newWhileOp = rewriter.create(whileOp.getLoc(), + argsTypesAfter, initArgs); + // Add before/after regions to the new op. - SmallVector bbArgLocs(initArgs.size(), whileOp.getLoc()); + SmallVector bbArgLocsBefore(initArgs.size(), whileOp.getLoc()); + SmallVector bbArgLocsAfter(argsTypesAfter.size(), + whileOp.getLoc()); Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock(); - newWhileOp.getBefore().addArguments(argsTypes, bbArgLocs); + newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore); Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock(); - newWhileOp.getAfter().addArguments(argsTypes, bbArgLocs); + newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter); // Set up new iter_args and move the loop condition block to the new op. // The old block uses tensors, so wrap the (memref) bbArgs of the new block // in ToTensorOps. rewriter.setInsertionPointToStart(newBeforeBody); SmallVector newBeforeArgs = getBbArgReplacements( - rewriter, newWhileOp.getBeforeArguments(), indices); + rewriter, newWhileOp.getBeforeArguments(), indicesBefore); rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs); // Update scf.condition of new loop. auto newConditionOp = newWhileOp.getConditionOp(); rewriter.setInsertionPoint(newConditionOp); + // Only equivalent buffers or new buffer allocations may be yielded to the + // "after" region. + // TODO: This could be relaxed for better bufferization results. SmallVector newConditionArgs = - getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypes, indices, - equivalentYieldsBefore, state); + getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypesAfter, + indicesAfter, equivalentYieldsBefore, state); newConditionOp.getArgsMutable().assign(newConditionArgs); // Set up new iter_args and move the loop body block to the new op. // The old block uses tensors, so wrap the (memref) bbArgs of the new block // in ToTensorOps. rewriter.setInsertionPointToStart(newAfterBody); - SmallVector newAfterArgs = - getBbArgReplacements(rewriter, newWhileOp.getAfterArguments(), indices); + SmallVector newAfterArgs = getBbArgReplacements( + rewriter, newWhileOp.getAfterArguments(), indicesAfter); rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs); // Update scf.yield of the new loop. auto newYieldOp = newWhileOp.getYieldOp(); rewriter.setInsertionPoint(newYieldOp); + // Only equivalent buffers or new buffer allocations may be yielded to the + // "before" region. + // TODO: This could be relaxed for better bufferization results. SmallVector newYieldValues = - getYieldedValues(rewriter, newYieldOp.getResults(), argsTypes, indices, - equivalentYieldsAfter, state); + getYieldedValues(rewriter, newYieldOp.getResults(), argsTypesBefore, + indicesBefore, equivalentYieldsAfter, state); newYieldOp.getResultsMutable().assign(newYieldValues); // Replace loop results. diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -111,7 +111,7 @@ auto collapseShapeOp = cast(op); RankedTensorType tensorResultType = collapseShapeOp.getResultType(); OpOperand &srcOperand = collapseShapeOp->getOpOperand(0) /*src*/; - auto bufferType = state.getBufferType(srcOperand).cast(); + auto bufferType = state.getBufferType(srcOperand.get()).cast(); if (tensorResultType.getRank() == 0) { // 0-d collapses must go through a different op builder. diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir --- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir @@ -449,3 +449,36 @@ // CHECK: return %[[loop]]#0, %[[loop]]#1 return %r0, %r1 : tensor<5xi1>, tensor<5xi1> } + +// ----- + +// CHECK-LABEL: func @scf_while_iter_arg_result_mismatch( +// CHECK-SAME: %[[arg0:.*]]: memref<5xi1, #{{.*}}>, %[[arg1:.*]]: memref<5xi1, #{{.*}}> +// CHECK: %[[alloc1:.*]] = memref.alloc() {{.*}} : memref<5xi1> +// CHECK: %[[alloc2:.*]] = memref.alloc() {{.*}} : memref<5xi1> +// CHECK: scf.while (%[[arg3:.*]] = %[[arg1]]) : (memref<5xi1, #{{.*}}) -> () { +// CHECK: %[[load:.*]] = memref.load %[[arg0]] +// CHECK: scf.condition(%[[load]]) +// CHECK: } do { +// CHECK: memref.copy %[[arg0]], %[[alloc2]] +// CHECK: memref.store %{{.*}}, %[[alloc2]] +// CHECK: memref.copy %[[alloc2]], %[[alloc1]] +// CHECK: %[[casted:.*]] = memref.cast %[[alloc1]] : memref<5xi1> to memref<5xi1, #{{.*}}> +// CHECK: scf.yield %[[casted]] +// CHECK: } +// CHECK-DAG: memref.dealloc %[[alloc1]] +// CHECK-DAG: memref.dealloc %[[alloc2]] +func.func @scf_while_iter_arg_result_mismatch(%arg0: tensor<5xi1>, + %arg1: tensor<5xi1>, + %arg2: index) { + scf.while (%arg3 = %arg1) : (tensor<5xi1>) -> () { + %0 = tensor.extract %arg0[%arg2] : tensor<5xi1> + scf.condition(%0) + } do { + %0 = "dummy.some_op"() : () -> index + %1 = "dummy.another_op"() : () -> i1 + %2 = tensor.insert %1 into %arg0[%0] : tensor<5xi1> + scf.yield %2 : tensor<5xi1> + } + return +}