diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -75,41 +75,17 @@
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
-
-    // Compute new result types.
-    SmallVector<Type> newResultTypes;
-    for (Type type : executeRegionOp->getResultTypes()) {
-      if (auto tensorType = type.dyn_cast<TensorType>()) {
-        // TODO: Infer the result type instead of computing it.
-        newResultTypes.push_back(getMemRefType(tensorType, options));
-      } else {
-        newResultTypes.push_back(type);
-      }
-    }
+    assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
+           "only 1 block supported");
+    auto yieldOp = cast<scf::YieldOp>(
+        executeRegionOp.getRegion().front().getTerminator());
+    TypeRange newResultTypes(yieldOp.getResults());
 
     // Create new op and move over region.
     auto newOp =
         rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
     newOp.getRegion().takeBody(executeRegionOp.getRegion());
 
-    // Update terminator.
-    assert(newOp.getRegion().getBlocks().size() == 1 &&
-           "only 1 block supported");
-    Block *newBlock = &newOp.getRegion().front();
-    auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
-    rewriter.setInsertionPoint(yieldOp);
-    SmallVector<Value> newYieldValues;
-    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
-      Value val = it.value();
-      if (val.getType().isa<TensorType>()) {
-        newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>(
-            yieldOp.getLoc(), newResultTypes[it.index()], val));
-      } else {
-        newYieldValues.push_back(val);
-      }
-    }
-    rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);
-
     // Update all uses of the old op.
     rewriter.setInsertionPointAfter(newOp);
     SmallVector<Value> newResults;
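In other words, scf.execute_region now infers its memref result types from the
already-bufferized terminator instead of recomputing them with getMemRefType.
Roughly (schematic IR, value names illustrative; one-shot bufferize tolerates
the transient tensor/memref mismatch between parent and terminator because ops
are rewritten one at a time):

  %0 = scf.execute_region -> tensor<5xf32> {
    ...
    scf.yield %m : memref<5xf32>   // terminator already bufferized
  }
  // becomes, with the result type read off %m:
  %0 = scf.execute_region -> memref<5xf32> {
    ...
    scf.yield %m : memref<5xf32>
  }
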
@@ -184,64 +160,62 @@
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
+    OpBuilder::InsertionGuard g(rewriter);
     auto ifOp = cast<scf::IfOp>(op);
-
-    // Compute new types of the bufferized scf.if op.
-    SmallVector<Type> newTypes;
-    for (Type returnType : ifOp->getResultTypes()) {
-      if (auto tensorType = returnType.dyn_cast<TensorType>()) {
-        // TODO: Infer the result type instead of computing it.
-        newTypes.push_back(getMemRefType(tensorType, options));
-      } else {
-        newTypes.push_back(returnType);
+    auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
+    auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
+
+    // Reconcile type mismatches between then/else branches by inserting memref
+    // casts.
+    SmallVector<Value> thenResults, elseResults;
+    bool insertedCast = false;
+    for (int64_t i = 0; i < thenYieldOp.getResults().size(); ++i) {
+      Value thenValue = thenYieldOp.getResults()[i];
+      Value elseValue = elseYieldOp.getResults()[i];
+      if (thenValue.getType() == elseValue.getType()) {
+        thenResults.push_back(thenValue);
+        elseResults.push_back(elseValue);
+        continue;
       }
+
+      // Type mismatch between then/else yield value. Cast both to a memref
+      // type with a fully dynamic layout map.
+      auto thenMemrefType = thenValue.getType().cast<BaseMemRefType>();
+      auto elseMemrefType = elseValue.getType().cast<BaseMemRefType>();
+      if (thenMemrefType.getMemorySpaceAsInt() !=
+          elseMemrefType.getMemorySpaceAsInt())
+        return op->emitError("inconsistent memory space on then/else branches");
+      rewriter.setInsertionPoint(thenYieldOp);
+      BaseMemRefType memrefType = getMemRefTypeWithFullyDynamicLayout(
+          ifOp.getResultTypes()[i].cast<TensorType>(),
+          thenMemrefType.getMemorySpaceAsInt());
+      thenResults.push_back(rewriter.create<memref::CastOp>(
+          thenYieldOp.getLoc(), memrefType, thenValue));
+      rewriter.setInsertionPoint(elseYieldOp);
+      elseResults.push_back(rewriter.create<memref::CastOp>(
+          elseYieldOp.getLoc(), memrefType, elseValue));
+      insertedCast = true;
+    }
+
+    if (insertedCast) {
+      rewriter.setInsertionPoint(thenYieldOp);
+      rewriter.replaceOpWithNewOp<scf::YieldOp>(thenYieldOp, thenResults);
+      rewriter.setInsertionPoint(elseYieldOp);
+      rewriter.replaceOpWithNewOp<scf::YieldOp>(elseYieldOp, elseResults);
     }
 
     // Create new op.
+    rewriter.setInsertionPoint(ifOp);
+    ValueRange resultsValueRange(thenResults);
+    TypeRange newTypes(resultsValueRange);
     auto newIfOp =
         rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
                                    /*withElseRegion=*/true);
 
-    // Remove terminators.
-    if (!newIfOp.thenBlock()->empty()) {
-      rewriter.eraseOp(newIfOp.thenBlock()->getTerminator());
-      rewriter.eraseOp(newIfOp.elseBlock()->getTerminator());
-    }
-
     // Move over then/else blocks.
     rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
     rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());
 
-    // Update scf.yield of new then-block.
-    auto thenYieldOp = cast<scf::YieldOp>(newIfOp.thenBlock()->getTerminator());
-    rewriter.setInsertionPoint(thenYieldOp);
-    SmallVector<Value> thenYieldValues;
-    for (OpOperand &operand : thenYieldOp->getOpOperands()) {
-      if (operand.get().getType().isa<TensorType>()) {
-        ensureToMemrefOpIsValid(operand.get(),
-                                newTypes[operand.getOperandNumber()]);
-        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
-            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
-            operand.get());
-        operand.set(toMemrefOp);
-      }
-    }
-
-    // Update scf.yield of new else-block.
-    auto elseYieldOp = cast<scf::YieldOp>(newIfOp.elseBlock()->getTerminator());
-    rewriter.setInsertionPoint(elseYieldOp);
-    SmallVector<Value> elseYieldValues;
-    for (OpOperand &operand : elseYieldOp->getOpOperands()) {
-      if (operand.get().getType().isa<TensorType>()) {
-        ensureToMemrefOpIsValid(operand.get(),
-                                newTypes[operand.getOperandNumber()]);
-        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
-            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
-            operand.get());
-        operand.set(toMemrefOp);
-      }
-    }
-
     // Replace op results.
     replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());
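When the two branches bufferize to memrefs that differ only in layout, both
yields are casted to a memref type with a fully dynamic layout map before the
new scf.if is built. Schematically (illustrative types; #map is the fully
dynamic layout of a 1-D memref):

  #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
  %r = scf.if %c -> (memref<?xf32, #map>) {
    %0 = memref.cast %then_buf : memref<?xf32> to memref<?xf32, #map>
    scf.yield %0 : memref<?xf32, #map>
  } else {
    %1 = memref.cast %else_buf
        : memref<?xf32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref<?xf32, #map>
    scf.yield %1 : memref<?xf32, #map>
  }
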
@@ -471,6 +445,13 @@
     return success();
   }
 
+  BaseMemRefType getBufferType(Operation *op, BlockArgument bbArg,
+                               const BufferizationOptions &options) const {
+    auto forOp = cast<scf::ForOp>(op);
+    return bufferization::getBufferType(
+        forOp.getOpOperandForRegionIterArg(bbArg).get(), options);
+  }
+
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto forOp = cast<scf::ForOp>(op);
@@ -500,20 +481,9 @@
         getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
     iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());
 
-    // Erase terminator if present.
-    if (iterArgs.size() == 1)
-      rewriter.eraseOp(loopBody->getTerminator());
-
     // Move loop body to new loop.
     rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);
 
-    // Update scf.yield of new loop.
-    auto yieldOp = cast<scf::YieldOp>(loopBody->getTerminator());
-    rewriter.setInsertionPoint(yieldOp);
-    SmallVector<Value> yieldValues = getYieldedValues(
-        rewriter, yieldOp.getResults(), initArgsTypes, indices, options);
-    yieldOp.getResultsMutable().assign(yieldValues);
-
     // Replace loop results.
     replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());
@@ -869,6 +839,31 @@
     if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp, scf::WhileOp>(
             yieldOp->getParentOp()))
       return yieldOp->emitError("unsupported scf::YieldOp parent");
+
+    // TODO: Bufferize scf.yield inside scf.while here. (Currently bufferized
+    // together with scf.while.)
+    if (isa<scf::WhileOp>(yieldOp->getParentOp()))
+      return success();
+
+    SmallVector<Value> newResults;
+    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
+      Value value = it.value();
+      if (value.getType().isa<TensorType>()) {
+        Value buffer = getBuffer(rewriter, value, options);
+        if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
+          BaseMemRefType resultType =
+              cast<BufferizableOpInterface>(forOp.getOperation())
+                  .getBufferType(forOp.getRegionIterArgs()[it.index()],
+                                 options);
+          buffer = castBuffer(rewriter, buffer, resultType);
+        }
+        newResults.push_back(buffer);
+      } else {
+        newResults.push_back(value);
+      }
+    }
+
+    replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
     return success();
   }
 };
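scf.yield is now rewritten by its own interface implementation; inside an
scf.for it casts the yielded buffer to the iter_arg's buffer type, which the
new getBufferType hook derives from the corresponding init operand. A sketch
(illustrative names; assumes %init already bufferized to a fully dynamic
layout):

  #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
  %0 = scf.for %iv = %lb to %ub step %c1
      iter_args(%arg0 = %init) -> (memref<?xf32, #map>) {
    %alloc = memref.alloc(%sz) : memref<?xf32>
    ...
    %cast = memref.cast %alloc : memref<?xf32> to memref<?xf32, #map>
    scf.yield %cast : memref<?xf32, #map>
  }
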
-          auto srcMemref = b.create<bufferization::ToMemrefOp>(
-              loc, srcType, insertOp.getSource());
-          Value destMemref = newResults[resultCounter++];
-          Value subview = b.create<memref::SubViewOp>(
-              loc, destMemref, insertOp.getMixedOffsets(),
-              insertOp.getMixedSizes(), insertOp.getMixedStrides());
-          // This memcpy will fold away if everything bufferizes in-place.
-          if (failed(options.createMemCpy(b, insertOp.getLoc(), srcMemref,
-                                          subview)))
-            return WalkResult::interrupt();
-          b.eraseOp(insertOp);
-          return WalkResult::advance();
-        });
-    if (walkResult.wasInterrupted())
-      return failure();
+    rewriter.mergeBlocks(foreachThreadOp.getBody(),
+                         newForeachThreadOp.getBody(),
+                         {newForeachThreadOp.getBody()->getArguments()});
 
-    // Replace the op.
-    replaceOpWithBufferizedValues(b, op, newResults);
+    // Remove the old op.
+    rewriter.eraseOp(op);
 
     return success();
   }
@@ -1104,9 +1067,50 @@
     return success();
   }
 
-  LogicalResult bufferize(Operation *op, RewriterBase &b,
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
-    // Will be bufferized as part of ForeachThreadOp.
+    OpBuilder::InsertionGuard g(rewriter);
+    auto insertOp = cast<ParallelInsertSliceOp>(op);
+    auto performConcurrentlyOp = cast<PerformConcurrentlyOp>(op->getParentOp());
+    auto foreachThreadOp =
+        cast<ForeachThreadOp>(performConcurrentlyOp->getParentOp());
+
+    // If the op bufferizes out-of-place, allocate the copy before the
+    // ForeachThreadOp.
+    rewriter.setInsertionPoint(foreachThreadOp);
+    Value destBuffer = getBuffer(rewriter, insertOp.getDest(), options);
+
+    // Bufferize the ParallelInsertSliceOp outside of the
+    // PerformConcurrentlyOp.
+    rewriter.setInsertionPoint(performConcurrentlyOp);
+    Value srcBuffer = getBuffer(rewriter, insertOp.getSource(), options);
+    Value subview = rewriter.create<memref::SubViewOp>(
+        insertOp.getLoc(), destBuffer, insertOp.getMixedOffsets(),
+        insertOp.getMixedSizes(), insertOp.getMixedStrides());
+    // This memcpy will fold away if everything bufferizes in-place.
+    if (failed(options.createMemCpy(rewriter, insertOp.getLoc(), srcBuffer,
+                                    subview)))
+      return failure();
+    rewriter.eraseOp(op);
+
+    // Replace all uses of the corresponding ForeachThreadOp result.
+    rewriter.setInsertionPointAfter(foreachThreadOp);
+    Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
+        foreachThreadOp.getLoc(), destBuffer);
+    unsigned resultNum = 0;
+    for (Operation &nextOp : performConcurrentlyOp.yieldingOps()) {
+      if (&nextOp == op)
+        break;
+      resultNum++;
+    }
+    assert(resultNum < foreachThreadOp->getNumResults() &&
+           "ParallelInsertSliceOp not found in PerformConcurrentlyOp");
+    SmallVector<OpOperand *> resultUses = llvm::to_vector(
+        llvm::map_range(foreachThreadOp->getResult(resultNum).getUses(),
+                        [](OpOperand &use) { return &use; }));
+    for (OpOperand *use : resultUses) {
+      rewriter.updateRootInPlace(use->getOwner(),
+                                 [&]() { use->set(toTensorOp); });
+    }
     return success();
   }
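Each ParallelInsertSliceOp now bufferizes independently of the enclosing loop:
the destination buffer is materialized before the loop, the subview and copy
land right before the perform_concurrently terminator, and all uses of the
matching foreach_thread result are redirected to a to_tensor of the
destination buffer. Roughly (schematic IR, op syntax abbreviated; #off is the
offset-only layout produced by the subview):

  #off = affine_map<(d0)[s0] -> (d0 + s0)>
  scf.foreach_thread (%tid) in (%n) {
    ...
    %sv = memref.subview %dest[%o] [%sz] [1]
        : memref<?xf32> to memref<?xf32, #off>
    memref.copy %src, %sv : memref<?xf32> to memref<?xf32, #off>
    scf.foreach_thread.perform_concurrently {
    }
  }
  %r = bufferization.to_tensor %dest : memref<?xf32>
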
diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -61,42 +61,17 @@
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto assumingOp = cast<shape::AssumingOp>(op);
-
-    // Compute new result types.
-    SmallVector<Type> newResultTypes;
-    for (Type type : assumingOp->getResultTypes()) {
-      if (auto tensorType = type.dyn_cast<TensorType>()) {
-        // TODO: Infer the result type instead of computing it.
-        newResultTypes.push_back(getMemRefType(tensorType, options));
-      } else {
-        newResultTypes.push_back(type);
-      }
-    }
+    assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
+           "only 1 block supported");
+    auto yieldOp = cast<shape::AssumingYieldOp>(
+        assumingOp.getDoRegion().front().getTerminator());
 
     // Create new op and move over region.
+    TypeRange newResultTypes(yieldOp.operands());
     auto newOp = rewriter.create<shape::AssumingOp>(
         op->getLoc(), newResultTypes, assumingOp.getWitness());
     newOp.getDoRegion().takeBody(assumingOp.getRegion());
 
-    // Update terminator.
-    assert(newOp.getDoRegion().getBlocks().size() == 1 &&
-           "only 1 block supported");
-    Block *newBlock = &newOp.getDoRegion().front();
-    auto yieldOp = cast<shape::AssumingYieldOp>(newBlock->getTerminator());
-    rewriter.setInsertionPoint(yieldOp);
-    SmallVector<Value> newYieldValues;
-    for (const auto &it : llvm::enumerate(yieldOp.operands())) {
-      Value val = it.value();
-      if (val.getType().isa<TensorType>()) {
-        newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>(
-            yieldOp.getLoc(), newResultTypes[it.index()], val));
-      } else {
-        newYieldValues.push_back(val);
-      }
-    }
-    rewriter.replaceOpWithNewOp<shape::AssumingYieldOp>(yieldOp,
-                                                        newYieldValues);
-
     // Update all uses of the old op.
     rewriter.setInsertionPointAfter(newOp);
     SmallVector<Value> newResults;
@@ -153,7 +128,14 @@
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
-    // Op is bufferized as part of AssumingOp.
+    auto yieldOp = cast<shape::AssumingYieldOp>(op);
+    SmallVector<Value> newResults;
+    for (Value value : yieldOp.operands())
+      newResults.push_back(value.getType().isa<TensorType>()
+                               ? getBuffer(rewriter, value, options)
+                               : value);
+    replaceOpWithNewBufferizedOp<shape::AssumingYieldOp>(rewriter, op,
+                                                         newResults);
    return success();
   }
 };
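shape.assuming follows the same pattern as scf.execute_region: its result
types are read off the bufferized shape.assuming_yield, which is now handled
by its own interface implementation. Sketch (illustrative types and names):

  %r = shape.assuming %w -> (memref<2xf16>) {
    ...
    shape.assuming_yield %m : memref<2xf16>
  }
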
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-allow-return-allocs.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-allow-return-allocs.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-allow-return-allocs.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-allow-return-allocs.mlir
@@ -8,6 +8,7 @@
 // CHECK-LABEL: func @buffer_not_deallocated(
 //  CHECK-SAME:     %[[t:.*]]: tensor<?xf32>
 func.func @buffer_not_deallocated(%t : tensor<?xf32>, %c : i1) -> tensor<?xf32> {
+  // CHECK: %[[m:.*]] = bufferization.to_memref %[[t]]
   // CHECK: %[[r:.*]] = scf.if %{{.*}} {
   %r = scf.if %c -> tensor<?xf32> {
     // CHECK: %[[some_op:.*]] = "test.some_op"
@@ -20,7 +21,6 @@
     scf.yield %0 : tensor<?xf32>
   } else {
     // CHECK: } else {
-    // CHECK: %[[m:.*]] = bufferization.to_memref %[[t]]
     // CHECK: %[[cloned:.*]] = bufferization.clone %[[m]]
     // CHECK: scf.yield %[[cloned]]
     scf.yield %t : tensor<?xf32>
@@ -40,8 +40,8 @@
                                                     %cond: i1,
                                                     %val: i32)
   -> tensor<i32>
 {
+  // CHECK: %[[arg0_m:.*]] = bufferization.to_memref %[[arg0]]
   // CHECK: %[[r:.*]] = scf.if {{.*}} {
-  // CHECK: %[[arg0_m:.*]] = bufferization.to_memref %[[arg0]]
   // CHECK: %[[clone:.*]] = bufferization.clone %[[arg0_m]]
   // CHECK: scf.yield %[[clone]]
   // CHECK: } else {
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir
@@ -206,9 +206,9 @@
 // CHECK-SCF-SAME: %[[t1:.*]]: tensor<?xf32> {bufferization.writable = true}, %[[c:.*]]: i1, %[[pos:.*]]: index
 func.func @simple_scf_if(%t1: tensor<?xf32> {bufferization.writable = true}, %c: i1,
                          %pos: index, %f: f32) -> (tensor<?xf32>, index) {
+  // CHECK-SCF: %[[t1_memref:.*]] = bufferization.to_memref %[[t1]]
   // CHECK-SCF: %[[r:.*]] = scf.if %[[c]] -> (memref<?xf32>) {
   %r1, %r2 = scf.if %c -> (tensor<?xf32>, index) {
-    // CHECK-SCF: %[[t1_memref:.*]] = bufferization.to_memref %[[t1]]
     // CHECK-SCF: scf.yield %[[t1_memref]]
     scf.yield %t1, %pos : tensor<?xf32>, index
   // CHECK-SCF: } else {
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -124,11 +124,10 @@
     scf.yield %f1, %t2, %f1 : f32, tensor<?xf32>, f32
   }
 
-  // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
   // CHECK: %[[load:.*]] = memref.load %[[m1]]
   %3 = tensor.extract %t1[%idx] : tensor<?xf32>
 
-  // CHECK: return %{{.*}}, %[[casted]], %[[load]] : f32, memref<?xf32>, f32
+  // CHECK: return %{{.*}}, %[[alloc]], %[[load]] : f32, memref<?xf32>, f32
   return %0, %1, %3 : f32, tensor<?xf32>, f32
 }