diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Operation.h"
@@ -228,6 +229,65 @@
   }
 };
 
+/// Bufferization of tensor.generate.
+struct GenerateOpInterface
+    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
+                                                    tensor::GenerateOp> {
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationState &state) const {
+    auto generateOp = cast<tensor::GenerateOp>(op);
+
+    // Allocate memory.
+    Location loc = op->getLoc();
+    MemRefType memrefType =
+        getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
+    FailureOr<Value> maybeResult =
+        createAlloc(rewriter, loc, memrefType, generateOp.dynamicExtents(),
+                    /*deallocMemref=*/state.getOptions().createDeallocs,
+                    state.getOptions());
+    if (failed(maybeResult))
+      return failure();
+    Value result = *maybeResult;
+
+    // Collect loop bounds.
+    int64_t rank = memrefType.getRank();
+    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    SmallVector<Value> lowerBounds(rank, zero);
+    SmallVector<Value> steps(rank, one);
+    SmallVector<Value> upperBounds;
+    int nextDynamicIndex = 0;
+    for (int i = 0; i < rank; i++) {
+      Value upperBound = memrefType.isDynamicDim(i)
+                             ? generateOp.dynamicExtents()[nextDynamicIndex++]
+                             : rewriter.create<arith::ConstantIndexOp>(
+                                   loc, memrefType.getDimSize(i));
+      upperBounds.push_back(upperBound);
+    }
+
+    // Generate tensor elements with a parallel loop that stores into
+    // each element of the resulting memref. We use mergeBlockBefore to "move"
+    // this op's body into the scf.parallel's body.
+    auto parallel =
+        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
+    Block *parallelBody = parallel.getBody();
+    rewriter.mergeBlockBefore(generateOp.getBody(),
+                              parallelBody->getTerminator(),
+                              parallelBody->getArguments());
+    // Replace the inlined yield op with a store op. The scf.parallel's builder
+    // already populated an scf.yield at the end, so we don't need to worry
+    // about creating that.
+    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
+    rewriter.setInsertionPointAfter(elementYield);
+    rewriter.replaceOpWithNewOp<memref::StoreOp>(
+        elementYield, elementYield->getOperands()[0], result,
+        parallelBody->getArguments());
+
+    replaceOpWithBufferizedValues(rewriter, op, result);
+    return success();
+  }
+};
+
 /// Bufferization of tensor.insert. Replace with memref.store.
 struct InsertOpInterface
     : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                     tensor::InsertOp> {
@@ ... @@
   registry.addOpInterface<tensor::DimOp, DimOpInterface>();
   registry.addOpInterface<tensor::ExtractSliceOp, ExtractSliceOpInterface>();
   registry.addOpInterface<tensor::ExtractOp, ExtractOpInterface>();
+  registry.addOpInterface<tensor::GenerateOp, GenerateOpInterface>();
   registry.addOpInterface<tensor::InsertOp, InsertOpInterface>();
   registry.addOpInterface<tensor::InsertSliceOp, InsertSliceOpInterface>();
   registry.addOpInterface<tensor::RankOp, RankOpInterface>();
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -1359,3 +1359,23 @@
 // CHECK: return %[[r]] : index
   return %0 : index
 }
+
+// -----
+
+// CHECK-LABEL: func @tensor_generate_static_and_dynamic(
+// CHECK-SAME: %[[arg0:.*]]: index
+func @tensor_generate_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> {
+  // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
+  // CHECK: %[[alloc:.*]] = memref.alloc(%[[arg0]]) {{.*}} : memref<16x?xindex>
+  // CHECK: scf.parallel (%[[arg1:.*]], %[[arg2:.*]]) = (%[[c0]], %[[c0]]) to (%[[c16]], %[[arg0]]) {{.*}} {
+  %result = tensor.generate %arg0 {
+  ^bb0(%i: index, %j: index):
+    %sum = arith.addi %i, %j : index
+    // CHECK: memref.store {{.*}}, %[[alloc]][%[[arg1]], %[[arg2]]]
+    // CHECK: scf.yield
+    tensor.yield %sum : index
+  } : tensor<16x?xindex>
+  // CHECK: }
+  return %result : tensor<16x?xindex>
+}
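
For context, the IR below is a hand-written sketch of roughly what the test function above bufferizes to, assembled from the CHECK lines. It assumes the usual comprehensive-bufferize rewriting of the function boundary from tensor to memref; the SSA value names and the omitted alloc attributes are illustrative, not the pass's verbatim output:

  func @tensor_generate_static_and_dynamic(%arg0: index) -> memref<16x?xindex> {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c16 = arith.constant 16 : index
    // One buffer holds the entire result; its dynamic extent is %arg0.
    %alloc = memref.alloc(%arg0) : memref<16x?xindex>
    // The inlined generate body computes one element per iteration and
    // stores it; the tensor.yield has become a memref.store.
    scf.parallel (%i, %j) = (%c0, %c0) to (%c16, %arg0) step (%c1, %c1) {
      %sum = arith.addi %i, %j : index
      memref.store %sum, %alloc[%i, %j] : memref<16x?xindex>
      scf.yield
    }
    return %alloc : memref<16x?xindex>
  }

scf.parallel is a natural lowering target here because tensor.generate defines each element independently of the others, so the per-element stores may execute in any order.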