diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -516,8 +516,6 @@
   // Certain buffers are not writeable:
   //   1. A function bbArg that is not inplaceable or
   //   2. A constant op.
-  assert(!aliasesNonWritableBuffer(opResult, aliasInfo, state) &&
-         "expected that opResult does not alias non-writable buffer");
   bool nonWritable =
       aliasesNonWritableBuffer(opOperand.get(), aliasInfo, state);
   if (!nonWritable)
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -131,27 +131,65 @@
                           BufferizationState &state) const {
     auto ifOp = cast<scf::IfOp>(op);
 
-    // Bufferize then/else blocks.
-    if (failed(comprehensive_bufferize::bufferize(ifOp.thenBlock(), state)))
-      return failure();
-    if (failed(comprehensive_bufferize::bufferize(ifOp.elseBlock(), state)))
-      return failure();
+    // Use IRRewriter instead of OpBuilder because it has additional helper
+    // functions.
+    IRRewriter rewriter(op->getContext());
+    rewriter.setInsertionPoint(ifOp);
+
+    // Compute new types of the bufferized scf.if op.
+    SmallVector<Type> newTypes;
+    for (Type returnType : ifOp->getResultTypes()) {
+      if (returnType.isa<TensorType>()) {
+        assert(returnType.isa<RankedTensorType>() &&
+               "unsupported unranked tensor");
+        newTypes.push_back(
+            getDynamicMemRefType(returnType.cast<RankedTensorType>()));
+      } else {
+        newTypes.push_back(returnType);
+      }
+    }
 
-    for (OpResult opResult : ifOp->getResults()) {
-      if (!opResult.getType().isa<TensorType>())
-        continue;
-      // TODO: Atm we bail on unranked TensorType because we don't know how to
-      // alloc an UnrankedMemRefType + its underlying ranked MemRefType.
-      assert(opResult.getType().isa<RankedTensorType>() &&
-             "unsupported unranked tensor");
+    // Create new op.
+    auto newIfOp =
+        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.condition(),
+                                   /*withElseRegion=*/true);
 
-      Value resultBuffer = state.getResultBuffer(opResult);
-      if (!resultBuffer)
-        return failure();
+    // Remove terminators.
+    if (!newIfOp.thenBlock()->empty()) {
+      rewriter.eraseOp(newIfOp.thenBlock()->getTerminator());
+      rewriter.eraseOp(newIfOp.elseBlock()->getTerminator());
+    }
 
-      state.mapBuffer(opResult, resultBuffer);
-    }
+    // Move over then/else blocks.
+    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
+    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());
+
+    // Update the scf.yield ops of the new then/else blocks: wrap every yielded
+    // tensor in a to_memref op of the matching bufferized result type.
+    auto updateYieldOp = [&](Block *block) {
+      auto yieldOp = cast<scf::YieldOp>(block->getTerminator());
+      rewriter.setInsertionPoint(yieldOp);
+      for (OpOperand &operand : yieldOp->getOpOperands()) {
+        if (!operand.get().getType().isa<TensorType>())
+          continue;
+        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
+            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
+            operand.get());
+        operand.set(toMemrefOp);
+      }
+    };
+    updateYieldOp(newIfOp.thenBlock());
+    updateYieldOp(newIfOp.elseBlock());
 
+    // Replace op results.
+    state.replaceOp(op, newIfOp->getResults());
+
+    // Bufferize then/else blocks.
+    if (failed(comprehensive_bufferize::bufferize(newIfOp.thenBlock(), state)))
+      return failure();
+    if (failed(comprehensive_bufferize::bufferize(newIfOp.elseBlock(), state)))
+      return failure();
+
     return success();
   }
 
@@ -299,27 +337,56 @@
         SmallVector<Operation *> &newOps) {
   LogicalResult status = success();
   op->walk([&](scf::YieldOp yieldOp) {
-    auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp());
-    if (!forOp)
-      return WalkResult::advance();
-
-    for (OpOperand &operand : yieldOp->getOpOperands()) {
-      auto tensorType = operand.get().getType().dyn_cast<TensorType>();
-      if (!tensorType)
-        continue;
-
-      OpOperand &forOperand = forOp.getOpOperandForResult(
-          forOp->getResult(operand.getOperandNumber()));
-      auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
-      if (!aliasInfo.areEquivalentBufferizedValues(operand.get(), bbArg)) {
-        // TODO: this could get resolved with copies but it can also turn into
-        // swaps so we need to be careful about order of copies.
-        status =
-            yieldOp->emitError()
-            << "Yield operand #" << operand.getOperandNumber()
-            << " does not bufferize to an equivalent buffer to the matching"
-            << " enclosing scf::for operand";
-        return WalkResult::interrupt();
+    if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
+      for (OpOperand &operand : yieldOp->getOpOperands()) {
+        auto tensorType = operand.get().getType().dyn_cast<TensorType>();
+        if (!tensorType)
+          continue;
+
+        OpOperand &forOperand = forOp.getOpOperandForResult(
+            forOp->getResult(operand.getOperandNumber()));
+        auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
+        if (!aliasInfo.areEquivalentBufferizedValues(operand.get(), bbArg)) {
+          // TODO: this could get resolved with copies but it can also turn into
+          // swaps so we need to be careful about order of copies.
+          status =
+              yieldOp->emitError()
+              << "Yield operand #" << operand.getOperandNumber()
+              << " does not bufferize to an equivalent buffer to the matching"
+              << " enclosing scf::for operand";
+          return WalkResult::interrupt();
+        }
+      }
+    }
+
+    if (auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp())) {
+      // IfOps are in destination passing style if all yielded tensors are
+      // a value or equivalent to a value that is defined outside of the IfOp.
+      for (OpOperand &operand : yieldOp->getOpOperands()) {
+        auto tensorType = operand.get().getType().dyn_cast<TensorType>();
+        if (!tensorType)
+          continue;
+
+        bool foundOutsideEquivalent = false;
+        aliasInfo.applyOnEquivalenceClass(operand.get(), [&](Value value) {
+          Operation *valueOp = value.getDefiningOp();
+          if (value.isa<BlockArgument>())
+            valueOp = value.cast<BlockArgument>().getOwner()->getParentOp();
+
+          bool inThenBlock = ifOp.thenBlock()->findAncestorOpInBlock(*valueOp);
+          bool inElseBlock = ifOp.elseBlock()->findAncestorOpInBlock(*valueOp);
+
+          if (!inThenBlock && !inElseBlock)
+            foundOutsideEquivalent = true;
+        });
+
+        if (!foundOutsideEquivalent) {
+          status = yieldOp->emitError()
+                   << "Yield operand #" << operand.getOperandNumber()
+                   << " does not bufferize to a buffer that is equivalent to a"
+                   << " buffer defined outside of the scf::if op";
+          return WalkResult::interrupt();
+        }
       }
     }
 
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
@@ -38,12 +38,12 @@
 func @scf_if_not_equivalent(
     %cond: i1, %t1: tensor<?xf32> {linalg.inplaceable = true},
     %idx: index) -> tensor<?xf32> {
-  // expected-error @+1 {{result buffer is ambiguous}}
   %r = scf.if %cond -> (tensor<?xf32>) {
     scf.yield %t1 : tensor<?xf32>
   } else {
     // This buffer aliases, but is not equivalent.
     %t2 = tensor.extract_slice %t1 [%idx] [%idx] [1] : tensor<?xf32> to tensor<?xf32>
+    // expected-error @+1 {{Yield operand #0 does not bufferize to a buffer that is equivalent to a buffer defined outside of the scf::if op}}
     scf.yield %t2 : tensor<?xf32>
   }
   return %r : tensor<?xf32>
@@ -127,9 +127,9 @@
 
 // -----
 
+// expected-error @+1 {{memref return type is unsupported}}
 func @scf_yield(%b : i1, %A : tensor<4xf32>, %B : tensor<4xf32>) -> tensor<4xf32>
 {
-  // expected-error @+1 {{result buffer is ambiguous}}
   %r = scf.if %b -> (tensor<4xf32>) {
     scf.yield %A : tensor<4xf32>
   } else {
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
@@ -194,3 +194,28 @@
   // CHECK-SCF: return %[[scf_for_tensor]]
   return %0 : tensor<?xf32>
 }
+
+// -----
+
+// CHECK-SCF-LABEL: func @simple_scf_if(
+//  CHECK-SCF-SAME:     %[[t1:.*]]: tensor<?xf32> {linalg.inplaceable = true}, %[[c:.*]]: i1, %[[pos:.*]]: index
+func @simple_scf_if(%t1: tensor<?xf32> {linalg.inplaceable = true}, %c: i1, %pos: index, %f: f32)
+    -> (tensor<?xf32>, index) {
+  // CHECK-SCF: %[[r:.*]] = scf.if %[[c]] -> (memref<?xf32, #{{.*}}>) {
+  %r1, %r2 = scf.if %c -> (tensor<?xf32>, index) {
+    // CHECK-SCF: %[[t1_memref:.*]] = bufferization.to_memref %[[t1]]
+    // CHECK-SCF: scf.yield %[[t1_memref]]
+    scf.yield %t1, %pos : tensor<?xf32>, index
+  // CHECK-SCF: } else {
+  } else {
+    // CHECK-SCF: %[[insert:.*]] = tensor.insert %{{.*}} into %[[t1]][{{.*}}]
+    // CHECK-SCF: %[[insert_memref:.*]] = bufferization.to_memref %[[insert]]
+    %1 = tensor.insert %f into %t1[%pos] : tensor<?xf32>
+    // CHECK-SCF: scf.yield %[[insert_memref]]
+    scf.yield %1, %pos : tensor<?xf32>, index
+  }
+
+  // CHECK-SCF: %[[r_tensor:.*]] = bufferization.to_tensor %[[r]]
+  // CHECK-SCF: return %[[r_tensor]], %[[pos]]
+  return %r1, %r2 : tensor<?xf32>, index
+}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -921,6 +921,22 @@
 
 // -----
 
+// CHECK-LABEL: func @scf_if_non_equiv_yields(
+//  CHECK-SAME:     %[[cond:.*]]: i1, %[[A:.*]]: memref<{{.*}}>, %[[B:.*]]: memref<{{.*}}>) -> memref<{{.*}}>
+func @scf_if_non_equiv_yields(%b : i1, %A : tensor<4xf32>, %B : tensor<4xf32>) -> tensor<4xf32>
+{
+  // CHECK: %[[r:.*]] = select %[[cond]], %[[A]], %[[B]]
+  %r = scf.if %b -> (tensor<4xf32>) {
+    scf.yield %A : tensor<4xf32>
+  } else {
+    scf.yield %B : tensor<4xf32>
+  }
+  // CHECK: return %[[r]]
+  return %r: tensor<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @insert_op
 //  CHECK-SAME:     %[[t1:.*]]: memref<?xf32, {{.*}}>, %[[s:.*]]: f32, %[[i:.*]]: index
 func @insert_op(%t1 : tensor<?xf32> {linalg.inplaceable = true},