diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -229,6 +229,82 @@ } }; +// Recursively enumerates the indices of `buffer` in row-major order (the last dimension varies fastest) while iterating over op.elements(); `indices` is updated in place and `constants[i]` is expected to hold index value i. +// At the innermost dimension it emits one memref.store of `*elementIt` per position and advances `elementIt`, so callers must pass at least max(shape) constants and at least as many elements as the product of `shape`. +static void createStores(RewriterBase &rewriter, Location loc, int dim, + Value buffer, ArrayRef shape, + ArrayRef constants, + OperandRange::iterator &elementIt, + SmallVectorImpl &indices) { + if (dim == static_cast(shape.size()) - 1) { + for (int i = 0; i < shape.back(); ++i) { + indices.back() = constants[i]; + rewriter.create(loc, *elementIt, buffer, indices); + ++elementIt; + } + return; + } + for (int i = 0; i < shape[dim]; ++i) { + indices[dim] = constants[i]; + createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt, + indices); + } +} + +/// Bufferization of tensor.from_elements: allocates a buffer with the result tensor's shape and element type, then stores each element operand into it in row-major order. NOTE(review): template arguments appear stripped in this text (e.g. `cast(op)`, `.cast()`, `rewriter.create(loc, ...)`, `ArrayRef shape`, `SmallVector constants`), so it will neither compile nor apply as-is; restore them. +struct FromElementsOpInterface + : public BufferizableOpInterface::ExternalModel { + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const BufferizationState &state) const { + auto fromElementsOp = cast(op); + + // Allocate a buffer for the result; a dealloc is requested only when the bufferization options set createDeallocs. + Location loc = op->getLoc(); + auto tensorType = fromElementsOp.getType().cast(); + auto shape = tensorType.getShape(); + MemRefType resultType = + MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + FailureOr maybeBuffer = + createAlloc(rewriter, loc, resultType, {}, + /*deallocMemref=*/state.getOptions().createDeallocs, + state.getOptions()); + if (failed(maybeBuffer)) + return failure(); + Value buffer = *maybeBuffer; + + // Case: no elements (e.g. tensor<0xelem_type>): nothing to store, the fresh buffer is the result. + if (fromElementsOp.elements().empty()) { + replaceOpWithBufferizedValues(rewriter, op, buffer); + return success(); + } + + // Case: tensor<elem_type> (rank 0): a single store with no indices.
+ if (shape.empty()) { + rewriter.create(loc, fromElementsOp.elements().front(), + buffer); + replaceOpWithBufferizedValues(rewriter, op, buffer); + return success(); + } + + // Create index constants for the range of possible indices [0, max{shape_i}) once; createStores shares them across all dimensions. + auto maxDim = *std::max_element(shape.begin(), shape.end()); + SmallVector constants; + constants.reserve(maxDim); + for (int i = 0; i < maxDim; ++i) + constants.push_back(rewriter.create(loc, i)); + + // Traverse all `elements` in row-major order and create `memref.store` ops; `indices` starts as all zeros (constants[0]) and is rewritten in place by createStores. + auto elementIt = fromElementsOp.elements().begin(); + SmallVector indices(tensorType.getRank(), constants[0]); + createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt, + indices); + + replaceOpWithBufferizedValues(rewriter, op, buffer); + return success(); + } +}; + /// Bufferization of tensor.generate. struct GenerateOpInterface : public BufferizableOpInterface::ExternalModel(); registry.addOpInterface(); registry.addOpInterface(); + registry.addOpInterface(); registry.addOpInterface(); registry.addOpInterface(); registry.addOpInterface(); diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -1379,3 +1379,24 @@ // CHECK: } return %result : tensor<16x?xindex> } + +// ----- + +// CHECK-LABEL: func @tensor_from_elements_2d( +// CHECK-SAME: %[[ELEM0:.*]]: index, %[[ELEM1:.*]]: index +func @tensor_from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> { + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index + // CHECK-DAG: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2xindex> + // CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]], %[[C0]]] + // CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C0]], %[[C1]]] + // CHECK: 
store %[[ELEM0]], %[[MEMREF]][%[[C1]], %[[C0]]] + // CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C1]], %[[C1]]] + // CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C2]], %[[C0]]] + // CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C2]], %[[C1]]] + %0 = tensor.from_elements %arg0, %arg1, %arg0, %arg1, %arg0, %arg1 + : tensor<3x2xindex> + // CHECK: return %[[MEMREF]] + return %0 : tensor<3x2xindex> +}