diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -75,41 +75,17 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto executeRegionOp = cast(op); - - // Compute new result types. - SmallVector newResultTypes; - for (Type type : executeRegionOp->getResultTypes()) { - if (auto tensorType = type.dyn_cast()) { - // TODO: Infer the result type instead of computing it. - newResultTypes.push_back(getMemRefType(tensorType, options)); - } else { - newResultTypes.push_back(type); - } - } + assert(executeRegionOp.getRegion().getBlocks().size() == 1 && + "only 1 block supported"); + auto yieldOp = + cast(executeRegionOp.getRegion().front().getTerminator()); + TypeRange newResultTypes(yieldOp.getResults()); // Create new op and move over region. auto newOp = rewriter.create(op->getLoc(), newResultTypes); newOp.getRegion().takeBody(executeRegionOp.getRegion()); - // Update terminator. - assert(newOp.getRegion().getBlocks().size() == 1 && - "only 1 block supported"); - Block *newBlock = &newOp.getRegion().front(); - auto yieldOp = cast(newBlock->getTerminator()); - rewriter.setInsertionPoint(yieldOp); - SmallVector newYieldValues; - for (const auto &it : llvm::enumerate(yieldOp.getResults())) { - Value val = it.value(); - if (val.getType().isa()) { - newYieldValues.push_back(rewriter.create( - yieldOp.getLoc(), newResultTypes[it.index()], val)); - } else { - newYieldValues.push_back(val); - } - } - rewriter.replaceOpWithNewOp(yieldOp, newYieldValues); - // Update all uses of the old op. rewriter.setInsertionPointAfter(newOp); SmallVector newResults; @@ -184,64 +160,62 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { + OpBuilder::InsertionGuard g(rewriter); auto ifOp = cast(op); - - // Compute new types of the bufferized scf.if op. - SmallVector newTypes; - for (Type returnType : ifOp->getResultTypes()) { - if (auto tensorType = returnType.dyn_cast()) { - // TODO: Infer the result type instead of computing it. - newTypes.push_back(getMemRefType(tensorType, options)); - } else { - newTypes.push_back(returnType); + auto thenYieldOp = cast(ifOp.thenBlock()->getTerminator()); + auto elseYieldOp = cast(ifOp.elseBlock()->getTerminator()); + + // Reconcile type mismatches between then/else branches by inserting memref + // casts. + SmallVector thenResults, elseResults; + bool insertedCast = false; + for (unsigned i = 0; i < thenYieldOp.getResults().size(); ++i) { + Value thenValue = thenYieldOp.getResults()[i]; + Value elseValue = elseYieldOp.getResults()[i]; + if (thenValue.getType() == elseValue.getType()) { + thenResults.push_back(thenValue); + elseResults.push_back(elseValue); + continue; } + + // Type mismatch between then/else yield value. Cast both to a memref type + // with a fully dynamic layout map. + auto thenMemrefType = thenValue.getType().cast(); + auto elseMemrefType = elseValue.getType().cast(); + if (thenMemrefType.getMemorySpaceAsInt() != + elseMemrefType.getMemorySpaceAsInt()) + return op->emitError("inconsistent memory space on then/else branches"); + rewriter.setInsertionPoint(thenYieldOp); + BaseMemRefType memrefType = getMemRefTypeWithFullyDynamicLayout( + ifOp.getResultTypes()[i].cast(), + thenMemrefType.getMemorySpaceAsInt()); + thenResults.push_back(rewriter.create( + thenYieldOp.getLoc(), memrefType, thenValue)); + rewriter.setInsertionPoint(elseYieldOp); + elseResults.push_back(rewriter.create( + elseYieldOp.getLoc(), memrefType, elseValue)); + insertedCast = true; + } + + if (insertedCast) { + rewriter.setInsertionPoint(thenYieldOp); + rewriter.replaceOpWithNewOp(thenYieldOp, thenResults); + rewriter.setInsertionPoint(elseYieldOp); + rewriter.replaceOpWithNewOp(elseYieldOp, elseResults); } // Create new op. + rewriter.setInsertionPoint(ifOp); + ValueRange resultsValueRange(thenResults); + TypeRange newTypes(resultsValueRange); auto newIfOp = rewriter.create(ifOp.getLoc(), newTypes, ifOp.getCondition(), /*withElseRegion=*/true); - // Remove terminators. - if (!newIfOp.thenBlock()->empty()) { - rewriter.eraseOp(newIfOp.thenBlock()->getTerminator()); - rewriter.eraseOp(newIfOp.elseBlock()->getTerminator()); - } - // Move over then/else blocks. rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock()); rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock()); - // Update scf.yield of new then-block. - auto thenYieldOp = cast(newIfOp.thenBlock()->getTerminator()); - rewriter.setInsertionPoint(thenYieldOp); - SmallVector thenYieldValues; - for (OpOperand &operand : thenYieldOp->getOpOperands()) { - if (operand.get().getType().isa()) { - ensureToMemrefOpIsValid(operand.get(), - newTypes[operand.getOperandNumber()]); - Value toMemrefOp = rewriter.create( - operand.get().getLoc(), newTypes[operand.getOperandNumber()], - operand.get()); - operand.set(toMemrefOp); - } - } - - // Update scf.yield of new else-block. - auto elseYieldOp = cast(newIfOp.elseBlock()->getTerminator()); - rewriter.setInsertionPoint(elseYieldOp); - SmallVector elseYieldValues; - for (OpOperand &operand : elseYieldOp->getOpOperands()) { - if (operand.get().getType().isa()) { - ensureToMemrefOpIsValid(operand.get(), - newTypes[operand.getOperandNumber()]); - Value toMemrefOp = rewriter.create( - operand.get().getLoc(), newTypes[operand.getOperandNumber()], - operand.get()); - operand.set(toMemrefOp); - } - } - // Replace op results. replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults()); @@ -869,6 +843,24 @@ if (!isa( yieldOp->getParentOp())) return yieldOp->emitError("unsupported scf::YieldOp parent"); + + // TODO: Bufferize scf.yield inside scf.while/scf.for here. + // (Currently bufferized together with scf.while/scf.for.) + if (isa(yieldOp->getParentOp())) + return success(); + + SmallVector newResults; + for (const auto &it : llvm::enumerate(yieldOp.getResults())) { + Value value = it.value(); + if (value.getType().isa()) { + Value buffer = getBuffer(rewriter, value, options); + newResults.push_back(buffer); + } else { + newResults.push_back(value); + } + } + + replaceOpWithNewBufferizedOp(rewriter, op, newResults); return success(); } }; diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-allow-return-allocs.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-allow-return-allocs.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-allow-return-allocs.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-allow-return-allocs.mlir @@ -8,6 +8,7 @@ // CHECK-LABEL: func @buffer_not_deallocated( // CHECK-SAME: %[[t:.*]]: tensor func.func @buffer_not_deallocated(%t : tensor, %c : i1) -> tensor { + // CHECK: %[[m:.*]] = bufferization.to_memref %[[t]] // CHECK: %[[r:.*]] = scf.if %{{.*}} { %r = scf.if %c -> tensor { // CHECK: %[[some_op:.*]] = "test.some_op" @@ -20,7 +21,6 @@ scf.yield %0 : tensor } else { // CHECK: } else { - // CHECK: %[[m:.*]] = bufferization.to_memref %[[t]] // CHECK: %[[cloned:.*]] = bufferization.clone %[[m]] // CHECK: scf.yield %[[cloned]] scf.yield %t : tensor @@ -40,8 +40,8 @@ %cond: i1, %val: i32) -> tensor { + // CHECK: %[[arg0_m:.*]] = bufferization.to_memref %[[arg0]] // CHECK: %[[r:.*]] = scf.if {{.*}} { - // CHECK: %[[arg0_m:.*]] = bufferization.to_memref %[[arg0]] // CHECK: %[[clone:.*]] = bufferization.clone %[[arg0_m]] // CHECK: scf.yield %[[clone]] // CHECK: } else { diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir @@ -206,9 +206,9 @@ // CHECK-SCF-SAME: %[[t1:.*]]: tensor {bufferization.writable = true}, %[[c:.*]]: i1, %[[pos:.*]]: index func.func @simple_scf_if(%t1: tensor {bufferization.writable = true}, %c: i1, %pos: index, %f: f32) -> (tensor, index) { + // CHECK-SCF: %[[t1_memref:.*]] = bufferization.to_memref %[[t1]] // CHECK-SCF: %[[r:.*]] = scf.if %[[c]] -> (memref) { %r1, %r2 = scf.if %c -> (tensor, index) { - // CHECK-SCF: %[[t1_memref:.*]] = bufferization.to_memref %[[t1]] // CHECK-SCF: scf.yield %[[t1_memref]] scf.yield %t1, %pos : tensor, index // CHECK-SCF: } else { diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir --- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir @@ -124,11 +124,10 @@ scf.yield %f1, %t2, %f1 : f32, tensor, f32 } - // CHECK: %[[casted:.*]] = memref.cast %[[alloc]] // CHECK: %[[load:.*]] = memref.load %[[m1]] %3 = tensor.extract %t1[%idx] : tensor - // CHECK: return %{{.*}}, %[[casted]], %[[load]] : f32, memref, f32 + // CHECK: return %{{.*}}, %[[alloc]], %[[load]] : f32, memref, f32 return %0, %1, %3 : f32, tensor, f32 }