diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -34,6 +34,16 @@
 // Helper methods.
 //===----------------------------------------------------------------------===//
 
+/// Reorders stored dimension to original dimension.
+static unsigned toOrig(const SparseTensorEncodingAttr &enc, unsigned i) {
+  auto order = enc.getDimOrdering();
+  if (order) {
+    assert(order.isPermutation());
+    return order.getDimPosition(i);
+  }
+  return i;
+}
+
 /// Reorders original dimension to stored dimension.
 static unsigned toStored(const SparseTensorEncodingAttr &enc, unsigned i) {
   auto order = enc.getDimOrdering();
@@ -87,7 +97,7 @@
     // tensor type.
     switch (enc.getDimLevelType()[r]) {
     case SparseTensorEncodingAttr::DimLevelType::Dense:
-      break;
+      break; // no fields
     case SparseTensorEncodingAttr::DimLevelType::Compressed:
     case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
     case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
@@ -111,7 +121,7 @@
   return TupleType::get(context, fields);
 }
 
-// Returns field index for pointers (d), indices (d) for set field.
+// Returns the field index of the sparse tensor type for pointers/indices, when set.
 static unsigned getFieldIndex(Type type, unsigned ptrDim, unsigned idxDim) {
   auto enc = getSparseTensorEncoding(type);
   assert(enc);
@@ -161,6 +171,94 @@
       builder.getIntegerAttr(indexType, field));
 }
 
+/// Creates a tuple.
+static Value createTupleMake(OpBuilder &builder, Location loc, Type type,
+                             ValueRange values) {
+  return builder.create<StorageOp>(loc, type, values);
+}
+
+/// Creates an allocation operation.
+static Value createAllocation(OpBuilder &builder, Location loc, Type type,
+                              Value sz) {
+  auto memType = MemRefType::get({ShapedType::kDynamicSize}, type);
+  return builder.create<memref::AllocOp>(loc, memType, sz);
+}
+
+/// Creates the allocation tuple for the given sparse tensor type.
+///
+/// TODO: for efficiency, we will need heuristics to make educated guesses
+///       on the required final sizes; also, we will need an improved
+///       memory allocation scheme with capacity and reallocation.
+///
+static Value createAllocTuple(OpBuilder &builder, Location loc, Type type,
+                              ValueRange dynSizes) {
+  auto enc = getSparseTensorEncoding(type);
+  assert(enc);
+  // Construct the basic types.
+  unsigned idxWidth = enc.getIndexBitWidth();
+  unsigned ptrWidth = enc.getPointerBitWidth();
+  RankedTensorType rType = type.cast<RankedTensorType>();
+  Type indexType = builder.getIndexType();
+  Type idxType = idxWidth ? builder.getIntegerType(idxWidth) : indexType;
+  Type ptrType = ptrWidth ? builder.getIntegerType(ptrWidth) : indexType;
+  Type eltType = rType.getElementType();
+  // Build the allocation tuple, using heuristics for pre-allocation.
+  auto shape = rType.getShape();
+  unsigned rank = shape.size();
+  SmallVector<Value> fields;
+  bool allDense = true;
+  Value one = constantIndex(builder, loc, 1);
+  Value linear = one;
+  Value heuristic = one; // FIX, see TODO above
+  // Build original sizes.
+  SmallVector<Value> sizes;
+  for (unsigned r = 0, o = 0; r < rank; r++) {
+    if (ShapedType::isDynamic(shape[r]))
+      sizes.push_back(dynSizes[o++]);
+    else
+      sizes.push_back(constantIndex(builder, loc, shape[r]));
+  }
+  // The dimSizes array.
+  Value dimSizes =
+      builder.create<memref::AllocOp>(loc, MemRefType::get({rank}, indexType));
+  fields.push_back(dimSizes);
+  // Per-dimension storage.
+  for (unsigned r = 0; r < rank; r++) {
+    // Get the original dimension (ro) for the current stored dimension.
+    unsigned ro = toOrig(enc, r);
+    builder.create<memref::StoreOp>(loc, sizes[ro], dimSizes,
+                                    constantIndex(builder, loc, r));
+    linear = builder.create<arith::MulIOp>(loc, linear, sizes[ro]);
+    // Allocate fields.
+    switch (enc.getDimLevelType()[r]) {
+    case SparseTensorEncodingAttr::DimLevelType::Dense:
+      break; // no fields
+    case SparseTensorEncodingAttr::DimLevelType::Compressed:
+    case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
+    case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
+    case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
+      fields.push_back(createAllocation(builder, loc, ptrType, heuristic));
+      fields.push_back(createAllocation(builder, loc, idxType, heuristic));
+      allDense = false;
+      break;
+    case SparseTensorEncodingAttr::DimLevelType::Singleton:
+    case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
+    case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
+    case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
+      fields.push_back(createAllocation(builder, loc, idxType, heuristic));
+      allDense = false;
+      break;
+    }
+  }
+  // The values array. For all-dense, the full length is required.
+  // In all other cases, we resort to the heuristic initial value.
+  Value valuesSz = allDense ? linear : heuristic;
+  fields.push_back(createAllocation(builder, loc, eltType, valuesSz));
+  // Construct tuple allocation.
+  Type tupleType = *convertSparseTensorType(type);
+  return createTupleMake(builder, loc, tupleType, fields);
+}
+
 /// Returns integral constant, if defined.
 static Optional<int64_t> getConstantInt(Value val) {
   if (auto constantOp = val.getDefiningOp<arith::ConstantOp>())
@@ -233,6 +331,28 @@
   }
 };
 
+/// Sparse codegen rule for the alloc operator.
+class SparseTensorAllocConverter
+    : public OpConversionPattern<bufferization::AllocTensorOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    RankedTensorType resType = op.getType();
+    auto enc = getSparseTensorEncoding(resType);
+    if (!enc)
+      return failure();
+    if (op.getCopy())
+      return rewriter.notifyMatchFailure(op, "tensor copy not implemented");
+    // Construct allocation tuple.
+    Value tuple = createAllocTuple(rewriter, op->getLoc(), resType,
+                                   adaptor.getOperands());
+    rewriter.replaceOp(op, tuple);
+    return success();
+  }
+};
+
 /// Sparse codegen rule for the dealloc operator.
 class SparseTensorDeallocConverter
     : public OpConversionPattern<bufferization::DeallocTensorOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
@@ -311,6 +431,22 @@
   }
 };
 
+/// Sparse codegen rule for tensor rematerialization.
+class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (op.getHasInserts()) {
+      // Finalize any pending insertions.
+      // TODO: implement
+    }
+    rewriter.replaceOp(op, adaptor.getOperands());
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -331,7 +467,8 @@
 void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
                                                RewritePatternSet &patterns) {
   patterns.add<SparseReturnConverter, SparseDimOpConverter, SparseCastConverter,
-               SparseTensorDeallocConverter, SparseToPointersConverter,
-               SparseToIndicesConverter, SparseToValuesConverter>(
+               SparseTensorAllocConverter, SparseTensorDeallocConverter,
+               SparseToPointersConverter, SparseToIndicesConverter,
+               SparseToValuesConverter, SparseTensorLoadConverter>(
       typeConverter, patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -156,7 +156,7 @@
     ConversionTarget target(*ctx);
     // Almost everything in the sparse dialect must go!
     target.addIllegalDialect<SparseTensorDialect>();
-    target.addLegalOp<StorageGetOp, StorageSetOp>();
+    target.addLegalOp<StorageGetOp, StorageSetOp, StorageOp>();
     // All dynamic rules below accept new function, call, return, and various
     // tensor and bufferization operations as legal output of the rewriting
     // provided that all sparse tensor types have been fully rewritten.
@@ -169,6 +169,10 @@
     target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
+    target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
+        [&](bufferization::AllocTensorOp op) {
+          return converter.isLegal(op.getType());
+        });
     target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
         [&](bufferization::DeallocTensorOp op) {
           return converter.isLegal(op.getTensor().getType());
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt %s --sparse-tensor-codegen --canonicalize --cse | FileCheck %s --check-prefix=CHECK-CODEGEN
-// RUN: mlir-opt %s --sparse-tensor-codegen --sparse-tensor-storage-expansion --canonicalize --cse | FileCheck %s --check-prefix=CHECK-STORAGE
-
+// FIXME:
+// R_U_N: mlir-opt %s --sparse-tensor-codegen --sparse-tensor-storage-expansion --canonicalize --cse | FileCheck %s --check-prefix=CHECK-STORAGE
 
 #SparseVector = #sparse_tensor.encoding<{
   dimLevelType = [ "compressed" ],
   indexBitWidth = 64,
   pointerBitWidth = 32
 }>
@@ -26,6 +26,11 @@
   pointerBitWidth = 32
 }>
 
+#CSC = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  dimOrdering = affine_map<(i, j) -> (j, i)>
+}>
+
 #DCSR = #sparse_tensor.encoding<{
   dimLevelType = [ "compressed", "compressed" ],
   indexBitWidth = 64,
@@ -45,7 +50,7 @@
 // CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<1xindex>,
 // CHECK-STORAGE-SAME: %[[A1:.*1]]: memref<?xi32>,
 // CHECK-STORAGE-SAME: %[[A2:.*2]]: memref<?xi64>,
-// CHECK-STORAGE-SAME: %[[A3:.*3]]: memref<?xf64>)
+// CHECK-STORAGE-SAME: %[[A3:.*3]]: memref<?xf64>)
 // CHECK-STORAGE: return %[[A0]], %[[A1]], %[[A2]], %[[A3]] : memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
 func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
   return %arg0 : tensor<?xf64, #SparseVector>
 }
@@ -59,7 +64,7 @@
 // CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<1xindex>,
 // CHECK-STORAGE-SAME: %[[A1:.*1]]: memref<?xi32>,
 // CHECK-STORAGE-SAME: %[[A2:.*2]]: memref<?xi64>,
-// CHECK-STORAGE-SAME: %[[A3:.*3]]: memref<?xf32>)
+// CHECK-STORAGE-SAME: %[[A3:.*3]]: memref<?xf32>)
 // CHECK-STORAGE: return %[[A0]], %[[A1]], %[[A2]], %[[A3]] : memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf32>
 func.func @sparse_nop_cast(%arg0: tensor<64xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
   %0 = tensor.cast %arg0 : tensor<64xf32, #SparseVector> to tensor<?xf32, #SparseVector>
@@ -72,7 +77,7 @@
 //
 // CHECK-STORAGE-LABEL: func @sparse_nop_cast_3d(
 // CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<3xindex>,
-// CHECK-STORAGE-SAME: %[[A1:.*1]]: memref<?xf32>)
+// CHECK-STORAGE-SAME: %[[A1:.*1]]: memref<?xf32>)
 // CHECK-STORAGE: return %[[A0]], %[[A1]] : memref<3xindex>, memref<?xf32>
 func.func @sparse_nop_cast_3d(%arg0: tensor<10x20x30xf32, #Dense3D>) -> tensor<?x?x?xf32, #Dense3D> {
   %0 = tensor.cast %arg0 : tensor<10x20x30xf32, #Dense3D> to tensor<?x?x?xf32, #Dense3D>
@@ -142,7 +147,7 @@
 //
 // CHECK-STORAGE-LABEL: func @sparse_dense_3d(
 // CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<3xindex>,
-// CHECK-STORAGE-SAME: %[[A1:.*1]]: memref<?xf64>)
+// CHECK-STORAGE-SAME: %[[A1:.*1]]: memref<?xf64>)
 // CHECK-STORAGE: %[[C:.*]] = arith.constant 20 : index
 // CHECK-STORAGE: return %[[C]] : index
 func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index {
@@ -165,7 +170,7 @@
 //
 // CHECK-STORAGE-LABEL: func @sparse_dense_3d_dyn(
 // CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<3xindex>,
-// CHECK-STORAGE-SAME: %[[A1:.*1]]: memref<?xf64>)
+// CHECK-STORAGE-SAME: %[[A1:.*1]]: memref<?xf64>)
 // CHECK-STORAGE: %[[C:.*]] = arith.constant 2 : index
 // CHECK-STORAGE: %[[L:.*]] = memref.load %[[A0]][%[[C]]] : memref<3xindex>
 // CHECK-STORAGE: return %[[L]] : index
@@ -186,7 +191,7 @@
 // CHECK-STORAGE-SAME: %[[A2:.*2]]: memref<?xi64>,
 // CHECK-STORAGE-SAME: %[[A3:.*3]]: memref<?xi32>,
 // CHECK-STORAGE-SAME: %[[A4:.*4]]: memref<?xi64>,
-// CHECK-STORAGE-SAME: %[[A5:.*5]]: memref<?xf64>)
+// CHECK-STORAGE-SAME: %[[A5:.*5]]: memref<?xf64>)
 // CHECK-STORAGE: return %[[A3]] : memref<?xi32>
 func.func @sparse_pointers_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xi32> {
   %c = arith.constant 1 : index
@@ -205,7 +210,7 @@
 // CHECK-STORAGE-SAME: %[[A2:.*2]]: memref<?xi64>,
 // CHECK-STORAGE-SAME: %[[A3:.*3]]: memref<?xi32>,
 // CHECK-STORAGE-SAME: %[[A4:.*4]]: memref<?xi64>,
-// CHECK-STORAGE-SAME: %[[A5:.*5]]: memref<?xf64>)
+// CHECK-STORAGE-SAME: %[[A5:.*5]]: memref<?xf64>)
 // CHECK-STORAGE: return %[[A4]] : memref<?xi64>
 func.func @sparse_indices_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xi64> {
   %c = arith.constant 1 : index
@@ -224,7 +229,7 @@
 // CHECK-STORAGE-SAME: %[[A2:.*2]]: memref<?xi64>,
 // CHECK-STORAGE-SAME: %[[A3:.*3]]: memref<?xi32>,
 // CHECK-STORAGE-SAME: %[[A4:.*4]]: memref<?xi64>,
-// CHECK-STORAGE-SAME: %[[A5:.*5]]: memref<?xf64>)
+// CHECK-STORAGE-SAME: %[[A5:.*5]]: memref<?xf64>)
 // CHECK-STORAGE: return %[[A5]] : memref<?xf64>
 func.func @sparse_values_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xf64> {
   %0 = sparse_tensor.values %arg0 : tensor<?x?xf64, #DCSR> to memref<?xf64>
@@ -257,3 +262,46 @@
   bufferization.dealloc_tensor %arg0 : tensor
   return
 }
+
+// CHECK-CODEGEN-LABEL: func @sparse_alloc_csc(
+// CHECK-CODEGEN-SAME: %[[A:.*]]: index)
+// CHECK-CODEGEN-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-CODEGEN-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-CODEGEN-DAG: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-CODEGEN: %[[T0:.*]] = memref.alloc() : memref<2xindex>
+// CHECK-CODEGEN: memref.store %[[A]], %[[T0]][%[[C0]]] : memref<2xindex>
+// CHECK-CODEGEN: memref.store %[[C10]], %[[T0]][%[[C1]]] : memref<2xindex>
+// CHECK-CODEGEN: %[[T1:.*]] = memref.alloc() : memref<1xindex>
+// CHECK-CODEGEN: %[[T2:.*]] = memref.cast %[[T1]] : memref<1xindex> to memref<?xindex>
+// CHECK-CODEGEN: %[[T3:.*]] = memref.alloc() : memref<1xindex>
+// CHECK-CODEGEN: %[[T4:.*]] = memref.cast %[[T3]] : memref<1xindex> to memref<?xindex>
+// CHECK-CODEGEN: %[[T5:.*]] = memref.alloc() : memref<1xf64>
+// CHECK-CODEGEN: %[[T6:.*]] = memref.cast %[[T5]] : memref<1xf64> to memref<?xf64>
+// CHECK-CODEGEN: %[[T:.*]] = sparse_tensor.storage(%[[T0]], %[[T2]], %[[T4]], %[[T6]])
+// CHECK-CODEGEN: return %[[T]] : tuple<memref<2xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>>
+func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> {
+  %0 = bufferization.alloc_tensor(%arg0) : tensor<10x?xf64, #CSC>
+  %1 = sparse_tensor.load %0 : tensor<10x?xf64, #CSC>
+  return %1 : tensor<10x?xf64, #CSC>
+}
+
+// CHECK-CODEGEN-LABEL: func @sparse_alloc_3d() -> tuple<memref<3xindex>, memref<?xf64>>
+// CHECK-CODEGEN-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-CODEGEN-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-CODEGEN-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-CODEGEN-DAG: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-CODEGEN-DAG: %[[C20:.*]] = arith.constant 20 : index
+// CHECK-CODEGEN-DAG: %[[C30:.*]] = arith.constant 30 : index
+// CHECK-CODEGEN: %[[A0:.*]] = memref.alloc() : memref<3xindex>
+// CHECK-CODEGEN: memref.store %[[C30]], %[[A0]][%[[C0]]] : memref<3xindex>
+// CHECK-CODEGEN: memref.store %[[C10]], %[[A0]][%[[C1]]] : memref<3xindex>
+// CHECK-CODEGEN: memref.store %[[C20]], %[[A0]][%[[C2]]] : memref<3xindex>
+// CHECK-CODEGEN: %[[A:.*]] = memref.alloc() : memref<6000xf64>
+// CHECK-CODEGEN: %[[A1:.*]] = memref.cast %[[A]] : memref<6000xf64> to memref<?xf64>
+// CHECK-CODEGEN: %[[T:.*]] = sparse_tensor.storage(%[[A0]], %[[A1]])
+// CHECK-CODEGEN: return %[[T]] : tuple<memref<3xindex>, memref<?xf64>>
+func.func @sparse_alloc_3d() -> tensor<10x20x30xf64, #Dense3D> {
+  %0 = bufferization.alloc_tensor() : tensor<10x20x30xf64, #Dense3D>
+  %1 = sparse_tensor.load %0 : tensor<10x20x30xf64, #Dense3D>
+  return %1 : tensor<10x20x30xf64, #Dense3D>
+}