diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -1082,6 +1082,55 @@
   }
 };
 
+/// Bufferization of tensor.splat. Bufferizes to a new allocation that is filled
+/// with a linalg.map. Similar to tensor.generate.
+struct SplatOpInterface
+    : public BufferizableOpInterface::ExternalModel<SplatOpInterface,
+                                                    tensor::SplatOp> {
+
+  bool bufferizesToAllocation(Operation *op, OpResult opResult) const {
+    return true;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    OpBuilder::InsertionGuard g(rewriter);
+    auto splatOp = cast<tensor::SplatOp>(op);
+
+    // Should the buffer be deallocated?
+    bool dealloc = shouldDeallocateOpResult(
+        cast<OpResult>(splatOp.getResult()), options);
+
+    // TODO: Implement memory space for this op.
+    if (options.defaultMemorySpace != Attribute())
+      return op->emitError("memory space not implemented yet");
+
+    // Allocate memory.
+    Location loc = op->getLoc();
+    FailureOr<Value> tensorAlloc =
+        allocateTensorForShapedValue(rewriter, loc, splatOp.getResult(),
+                                     /*escape=*/!dealloc, options,
+                                     /*copy=*/false);
+    if (failed(tensorAlloc))
+      return failure();
+
+    // Create linalg::MapOp.
+    auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
+    auto linalgOp =
+        rewriter.create<linalg::MapOp>(loc, tensorType, /*inputs=*/ValueRange(),
+                                       /*init=*/*tensorAlloc);
+    Block &linalgBody = linalgOp.getMapper().emplaceBlock();
+
+    // Fill the mapper body. The map has no inputs, so the block has no
+    // arguments; it only yields the splat operand (no linalg.index needed).
+    rewriter.setInsertionPointToStart(&linalgBody);
+    rewriter.create<linalg::YieldOp>(loc, splatOp.getInput());
+    rewriter.replaceOp(splatOp, linalgOp.getResult()[0]);
+
+    return success();
+  }
+};
+
 } // namespace
 } // namespace tensor
 } // namespace mlir
@@ -1105,6 +1154,7 @@
         *ctx);
     RankOp::attachInterface<RankOpInterface>(*ctx);
     ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
+    SplatOp::attachInterface<SplatOpInterface>(*ctx);
 
     // Load additional dialects of which ops may get created.
     ctx->loadDialect<arith::ArithDialect, linalg::LinalgDialect>();
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -582,3 +582,20 @@
   // CHECK: return %[[r]] : tensor
   return %0 : tensor
 }
+
+// -----
+
+// CHECK-LABEL: func @tensor.splat(
+// CHECK-SAME:    %[[F:.*]]: f32)
+// CHECK-DAG:     %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<10x2x4xf32>
+// CHECK:         %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK:         %[[MAPPED:.*]] = linalg.map
+// CHECK:         outs(%[[ALLOC_T]] : tensor<10x2x4xf32>)
+// CHECK:         linalg.yield %[[F]]
+// CHECK:         }
+// CHECK:         return %[[MAPPED]] : tensor<10x2x4xf32>
+// CHECK:       }
+func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> {
+  %t = tensor.splat %f : tensor<10x2x4xf32>
+  return %t : tensor<10x2x4xf32>
+}