diff --git a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp @@ -27,6 +27,21 @@ OwningRewritePatternList patterns; ConversionTarget target(*context); + // TODO: Move this to BufferizeTypeConverter's constructor. + // + // This doesn't currently play well with "finalizing" bufferizations (ones + // that expect all materializations to be gone). In particular, there seems + // to be at least a double-free in the dialect conversion framework + // when this materialization gets inserted and then folded away because + // it is marked as illegal. + typeConverter.addArgumentMaterialization( + [](OpBuilder &builder, RankedTensorType type, ValueRange inputs, + Location loc) -> Value { + assert(inputs.size() == 1); + assert(inputs[0].getType().isa()); + return builder.create(loc, type, inputs[0]); + }); + populateBufferizeMaterializationLegality(target); populateSCFStructuralTypeConversionsAndLegality(context, typeConverter, patterns, target); diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp --- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp @@ -31,16 +31,44 @@ newResultTypes.push_back(newType); } - // Clone and replace. - ForOp newOp = cast(rewriter.clone(*op.getOperation())); + // Clone the op without the regions and inline the regions from the old op. + // + // This is a little bit tricky. We have two concerns here: + // + // 1. We cannot update the op in place because the dialect conversion + // framework does not track type changes for ops updated in place, so it + // won't insert appropriate materializations on the changed result types. + // PR47938 tracks this issue, but it seems hard to fix. Instead, we need to + // clone the op. + // + // 2. 
We cannot simply call `op.clone()` to get the cloned op. Besides being + // inefficient to recursively clone the regions, there is a correctness + // issue: if we clone with the regions, then the dialect conversion + // framework thinks that we just inserted all the cloned child ops. But what + // we want is to "take" the child regions and let the dialect conversion + // framework continue recursively into ops inside those regions (which are + // already in its worklist; inlining them into the new op's regions doesn't + // remove the child ops from the worklist). + ForOp newOp = cast(rewriter.cloneWithoutRegions(*op.getOperation())); + // Move the loop body region from the old op into the new op. + rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(), + newOp.getLoopBody().end()); + + // Now, update all the types. + + // Convert the argument types of the entry block of the ForOp's body. + if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(), + *getTypeConverter()))) { + return rewriter.notifyMatchFailure(op, "could not convert body types"); + } + // Change the clone to use the updated operands. We could have cloned with + // a BlockAndValueMapping, but this seems a bit more direct. newOp.getOperation()->setOperands(operands); + // Update the result types to the new converted types. for (auto t : llvm::zip(newOp.getResults(), newResultTypes)) std::get<0>(t).setType(std::get<1>(t)); - auto bodyArgs = newOp.getBody()->getArguments(); - for (auto t : llvm::zip(llvm::drop_begin(bodyArgs, 1), newResultTypes)) - std::get<0>(t).setType(std::get<1>(t)); - rewriter.replaceOp(op, newOp.getResults()); + rewriter.replaceOp(op, newOp.getResults()); return success(); } }; @@ -71,9 +99,15 @@ newResultTypes.push_back(newType); } - // TODO: Write this with updateRootInPlace once the conversion infra - // supports source materializations on ops updated in place. 
- IfOp newOp = cast(rewriter.clone(*op.getOperation())); + // See comments in the ForOp pattern for why we clone without regions and + // then inline. + IfOp newOp = cast(rewriter.cloneWithoutRegions(*op.getOperation())); + rewriter.inlineRegionBefore(op.thenRegion(), newOp.thenRegion(), + newOp.thenRegion().end()); + rewriter.inlineRegionBefore(op.elseRegion(), newOp.elseRegion(), + newOp.elseRegion().end()); + + // Update the operands and result types. newOp.getOperation()->setOperands(operands); for (auto t : llvm::zip(newOp.getResults(), newResultTypes)) std::get<0>(t).setType(std::get<1>(t)); diff --git a/mlir/test/Dialect/SCF/bufferize.mlir b/mlir/test/Dialect/SCF/bufferize.mlir --- a/mlir/test/Dialect/SCF/bufferize.mlir +++ b/mlir/test/Dialect/SCF/bufferize.mlir @@ -29,7 +29,9 @@ // CHECK-SAME: %[[STEP:.*]]: index) -> tensor { // CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref // CHECK: %[[RESULT_MEMREF:.*]] = scf.for %[[VAL_6:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[ITER:.*]] = %[[MEMREF]]) -> (memref) { -// CHECK: scf.yield %[[ITER]] : memref +// CHECK: %[[TENSOR_ITER:.*]] = tensor_load %[[ITER]] : memref +// CHECK: %[[MEMREF_YIELDED:.*]] = tensor_to_memref %[[TENSOR_ITER]] : memref +// CHECK: scf.yield %[[MEMREF_YIELDED]] : memref // CHECK: } // CHECK: %[[VAL_8:.*]] = tensor_load %[[VAL_9:.*]] : memref // CHECK: return %[[VAL_8]] : tensor @@ -40,3 +42,40 @@ } return %ret : tensor } + +// Check that this converts at all. +// +// Previously, conversion would fail altogether. 
+// CHECK-LABEL: func @if_correct_recursive_legalization_behavior +// CHECK: "test.munge_tensor" +func @if_correct_recursive_legalization_behavior(%pred: i1, %tensor: tensor) -> tensor { + %0 = scf.if %pred -> (tensor) { + %1 = "test.munge_tensor"(%tensor) : (tensor) -> (tensor) + scf.yield %1 : tensor + } else { + %1 = "test.munge_tensor"(%tensor) : (tensor) -> (tensor) + scf.yield %1 : tensor + } + return %0 : tensor +} + +// CHECK-LABEL: func @for_correct_recursive_legalization_behavior( +// CHECK-SAME: %[[TENSOR:.*]]: tensor, +// CHECK-SAME: %[[INDEX:.*]]: index) -> tensor { +// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref +// CHECK: %[[RESULT:.*]] = scf.for %[[IV:.*]] = %[[INDEX]] to %[[INDEX]] step %[[INDEX]] iter_args(%[[MEMREF_ITER:.*]] = %[[MEMREF]]) -> (memref) { +// CHECK: %[[TENSOR_ITER:.*]] = tensor_load %[[MEMREF_ITER]] : memref +// CHECK: %[[TENSOR_MUNGED:.*]] = "test.munge_tensor"(%[[TENSOR_ITER]]) : (tensor) -> tensor +// CHECK: %[[MEMREF_MUNGED:.*]] = tensor_to_memref %[[TENSOR_MUNGED]] : memref +// CHECK: scf.yield %[[MEMREF_MUNGED]] : memref +// CHECK: } +// CHECK: %[[TENSOR_RESULT:.*]] = tensor_load %[[RESULT]] : memref +// CHECK: return %[[TENSOR_RESULT]] : tensor +// CHECK: } +func @for_correct_recursive_legalization_behavior(%arg0: tensor, %index: index) -> tensor { + %ret = scf.for %iv = %index to %index step %index iter_args(%iter = %arg0) -> tensor { + %0 = "test.munge_tensor"(%iter) : (tensor) -> (tensor) + scf.yield %0 : tensor + } + return %ret : tensor +}