diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -130,9 +130,11 @@
 /// Sets up sparse tensor conversion rules.
 void populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
-                                         RewritePatternSet &patterns);
+                                         RewritePatternSet &patterns,
+                                         bool enableBufferInitialization);
 
-std::unique_ptr<Pass> createSparseTensorCodegenPass();
+std::unique_ptr<Pass>
+createSparseTensorCodegenPass(bool enableBufferInitialization = false);
 
 //===----------------------------------------------------------------------===//
 // The SparseTensorRewriting pass.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -181,6 +181,10 @@
     "scf::SCFDialect",
     "sparse_tensor::SparseTensorDialect",
   ];
+  let options = [
+    Option<"enableBufferInitialization", "enable-buffer-initialization", "bool",
+           "false", "Enable zero-initialization of the memory buffers">,
+  ];
 }
 
 def SparseBufferRewrite : Pass<"sparse-buffer-rewrite", "ModuleOp"> {
diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
@@ -64,7 +64,8 @@
     pm.addPass(createSparseTensorConversionPass(
         options.sparseTensorConversionOptions()));
   else
-    pm.addPass(createSparseTensorCodegenPass());
+    pm.addPass(
+        createSparseTensorCodegenPass(options.enableBufferInitialization));
   pm.addPass(createSparseBufferRewritePass(options.enableBufferInitialization));
   pm.addPass(createDenseBufferizationPass(
       getBufferizationOptions(/*analysisOnly=*/false)));
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -287,9 +287,15 @@
 /// Creates allocation operation.
 static Value createAllocation(OpBuilder &builder, Location loc, Type type,
-                              Value sz) {
+                              Value sz, bool enableInit) {
   auto memType = MemRefType::get({ShapedType::kDynamicSize}, type);
-  return builder.create<memref::AllocOp>(loc, memType, sz);
+  Value buffer = builder.create<memref::AllocOp>(loc, memType, sz);
+  if (enableInit) {
+    Value fillValue =
+        builder.create<arith::ConstantOp>(loc, type, builder.getZeroAttr(type));
+    builder.create<linalg::FillOp>(loc, fillValue, buffer);
+  }
+  return buffer;
 }
 
 /// Creates allocation for each field in sparse tensor type. Note that
@@ -300,7 +306,7 @@
 ///       on the required capacities (see heuristic variable).
 ///
 static void createAllocFields(OpBuilder &builder, Location loc, Type type,
-                              ValueRange dynSizes,
+                              ValueRange dynSizes, bool enableInit,
                               SmallVectorImpl<Value> &fields) {
   auto enc = getSparseTensorEncoding(type);
   assert(enc);
@@ -334,16 +340,20 @@
   // Per-dimension storage.
   for (unsigned r = 0; r < rank; r++) {
     if (isCompressedDim(rtp, r)) {
-      fields.push_back(createAllocation(builder, loc, ptrType, heuristic));
-      fields.push_back(createAllocation(builder, loc, idxType, heuristic));
+      fields.push_back(
+          createAllocation(builder, loc, ptrType, heuristic, enableInit));
+      fields.push_back(
+          createAllocation(builder, loc, idxType, heuristic, enableInit));
     } else if (isSingletonDim(rtp, r)) {
-      fields.push_back(createAllocation(builder, loc, idxType, heuristic));
+      fields.push_back(
+          createAllocation(builder, loc, idxType, heuristic, enableInit));
     } else {
       assert(isDenseDim(rtp, r)); // no fields
     }
   }
   // The values array.
-  fields.push_back(createAllocation(builder, loc, eltType, heuristic));
+  fields.push_back(
+      createAllocation(builder, loc, eltType, heuristic, enableInit));
   assert(fields.size() == lastField);
   // Initialize the storage scheme to an empty tensor. Initialized memSizes
   // to all zeros, sets the dimSizes to known values and gives all pointer
@@ -685,6 +695,8 @@
     : public OpConversionPattern<bufferization::AllocTensorOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
+  SparseTensorAllocConverter(MLIRContext *context, bool enableInit)
+      : OpConversionPattern(context), enableBufferInitialization(enableInit) {}
   LogicalResult
   matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -698,11 +710,15 @@
     // Construct allocation for each field.
     Location loc = op.getLoc();
     SmallVector<Value> fields;
-    createAllocFields(rewriter, loc, resType, adaptor.getOperands(), fields);
+    createAllocFields(rewriter, loc, resType, adaptor.getOperands(),
+                      enableBufferInitialization, fields);
     // Replace operation with resulting memrefs.
     rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
     return success();
   }
+
+private:
+  bool enableBufferInitialization;
 };
 
 /// Sparse codegen rule for the dealloc operator.
@@ -1014,8 +1030,9 @@
 /// Populates the given patterns list with conversion rules required for
 /// the sparsification of linear algebra operations.
-void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
-                                               RewritePatternSet &patterns) {
+void mlir::populateSparseTensorCodegenPatterns(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    bool enableBufferInitialization) {
   patterns.add(
       typeConverter, patterns.getContext());
+  patterns.add<SparseTensorAllocConverter>(typeConverter, patterns.getContext(),
+                                           enableBufferInitialization);
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -161,6 +161,9 @@
   SparseTensorCodegenPass() = default;
   SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default;
+  SparseTensorCodegenPass(bool enableInit) {
+    enableBufferInitialization = enableInit;
+  }
 
   void runOnOperation() override {
     auto *ctx = &getContext();
@@ -203,7 +206,8 @@
                                                          converter);
     scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                          target);
-    populateSparseTensorCodegenPatterns(converter, patterns);
+    populateSparseTensorCodegenPatterns(converter, patterns,
+                                        enableBufferInitialization);
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       signalPassFailure();
@@ -278,8 +282,9 @@
   return std::make_unique<SparseTensorConversionPass>(options);
 }
 
-std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
-  return std::make_unique<SparseTensorCodegenPass>();
+std::unique_ptr<Pass>
+mlir::createSparseTensorCodegenPass(bool enableBufferInitialization) {
+  return std::make_unique<SparseTensorCodegenPass>(enableBufferInitialization);
 }
 
 std::unique_ptr<Pass>
diff --git a/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir b/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt %s --sparse-tensor-codegen=enable-buffer-initialization=true --canonicalize --cse | FileCheck %s
+
+#SV = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
+
+// CHECK-LABEL: func @sparse_alloc_sparse_vector(
+// CHECK-SAME: %[[A:.*]]: index) ->
+// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[F0:.*]] = arith.constant 0.{{0*}}e+00 : f64
+// CHECK: %[[T0:.*]] = memref.alloc() : memref<1xindex>
+// CHECK: %[[T1:.*]] = memref.alloc() : memref<3xindex>
+// CHECK: %[[T2:.*]] = memref.alloc() : memref<16xindex>
+// CHECK: %[[T3:.*]] = memref.cast %[[T2]] : memref<16xindex> to memref<?xindex>
+// CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T2]] : memref<16xindex>)
+// CHECK: %[[T4:.*]] = memref.alloc() : memref<16xindex>
+// CHECK: %[[T5:.*]] = memref.cast %[[T4]] : memref<16xindex> to memref<?xindex>
+// CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T4]] : memref<16xindex>)
+// CHECK: %[[T6:.*]] = memref.alloc() : memref<16xf64>
+// CHECK: %[[T7:.*]] = memref.cast %[[T6]] : memref<16xf64> to memref<?xf64>
+// CHECK: linalg.fill ins(%[[F0]] : f64) outs(%[[T6]] : memref<16xf64>)
+// CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T1]] : memref<3xindex>)
+// CHECK: memref.store %[[A]], %[[T0]][%[[C0]]] : memref<1xindex>
+// CHECK: %[[P0:.*]] = sparse_tensor.push_back %[[T1]], %[[T3]]
+// CHECK: %[[P1:.*]] = sparse_tensor.push_back %[[T1]], %[[P0]]
+// CHECK: return %[[T0]], %[[T1]], %[[P1]], %[[T5]], %[[T7]] :
+func.func @sparse_alloc_sparse_vector(%arg0: index) -> tensor<?xf64, #SV> {
+  %0 = bufferization.alloc_tensor(%arg0) : tensor<?xf64, #SV>
+  %1 = sparse_tensor.load %0 : tensor<?xf64, #SV>
+  return %1 : tensor<?xf64, #SV>
+}
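Usage note (not part of the patch): besides the `mlir-opt` flag exercised by the test above, callers that assemble their own pass pipeline in C++ can opt in through the `createSparseTensorCodegenPass` overload declared in Passes.h. The sketch below is a minimal, hypothetical driver; the function name and the idea of running only this single pass are assumptions for illustration, and only the pass-creation call itself comes from this patch.

```cpp
// Minimal sketch: run the sparse-tensor codegen pass over a module with
// zero-initialization of the allocated sparse buffers enabled. Roughly
// equivalent to: mlir-opt --sparse-tensor-codegen=enable-buffer-initialization=true
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"

static mlir::LogicalResult runSparseCodegen(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  // The boolean maps to the new 'enable-buffer-initialization' pass option.
  pm.addPass(
      mlir::createSparseTensorCodegenPass(/*enableBufferInitialization=*/true));
  return pm.run(module);
}
```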