diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp @@ -15,7 +15,6 @@ #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/Transforms/Passes.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/Transforms/DialectConversion.h" @@ -70,18 +69,29 @@ upperBounds.push_back(upperBound); } - // Generate tensor elements with a parallel loop. - rewriter.create( - loc, lowerBounds, upperBounds, steps, - [&](OpBuilder &b, Location loc, ValueRange ivs) { - BlockAndValueMapping mapping; - mapping.map(op.body().getArguments(), ivs); - for (auto &nestedOp : op.getBody()->without_terminator()) - b.clone(nestedOp, mapping); - auto yieldOp = cast(op.getBody()->getTerminator()); - b.create(loc, mapping.lookup(yieldOp.value()), result, ivs); - b.create(loc); - }); + // Generate tensor elements with a parallel loop that stores into + // each element of the resulting memref. + // + // This is a bit tricky. We cannot simply clone the ops because when an op + // is cloned, it must be legalized. However, we want to allow arbitrary ops + // in the body that we don't necessarily have legalization patterns for as + // part of this dialect conversion invocation. + // + // To accomplish this, we use mergeBlockBefore to "move" this op's body + // into the scf.parallel's body. + auto parallel = + rewriter.create(loc, lowerBounds, upperBounds, steps); + Block *parallelBody = parallel.getBody(); + rewriter.mergeBlockBefore(op.getBody(), parallelBody->getTerminator(), + parallelBody->getArguments()); + // Replace the inlined yield op with a store op. The scf.parallel's builder + // already populated an scf.yield at the end, so we don't need to worry + // about creating that. + Operation *elementYield = parallelBody->getTerminator()->getPrevNode(); + rewriter.setInsertionPointAfter(elementYield); + rewriter.replaceOpWithNewOp(elementYield, + elementYield->getOperands()[0], result, + parallelBody->getArguments()); rewriter.replaceOp(op, {result}); return success(); @@ -168,7 +178,6 @@ target.addLegalDialect(); target.addLegalDialect(); - target.addLegalDialect(); populateStdBufferizePatterns(context, typeConverter, patterns); target.addIllegalOp } -// The dynamic_tensor_from_elements op clones each op in its body. -// Make sure that regions nested within such ops are recursively converted. -// CHECK-LABEL: func @recursively_convert_cloned_regions -func @recursively_convert_cloned_regions(%arg0: tensor<*xf32>, %arg1: index, %arg2: i1) -> tensor { - %tensor = dynamic_tensor_from_elements %arg1 { +// The dynamic_tensor_from_elements op needs to put its body into the +// resulting scf.parallel. To handle unknown ops in the body, it cannot clone +// the body because that would require the cloned ops to be legalized +// immediately, which is usually not possible since they might be from various +// other dialects. +// +// CHECK-LABEL: func @unknown_ops_in_body +func @unknown_ops_in_body(%arg0: index) -> tensor { + // CHECK-NOT: dynamic_tensor_from_elements + %tensor = dynamic_tensor_from_elements %arg0 { ^bb0(%iv: index): - %48 = scf.if %arg2 -> (index) { - scf.yield %iv : index - } else { - // CHECK-NOT: dim{{.*}}tensor - %50 = dim %arg0, %iv : tensor<*xf32> - scf.yield %50 : index - } - yield %48 : index + // CHECK: test.source + %0 = "test.source"() : () -> index + yield %0 : index } : tensor return %tensor : tensor }