diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -101,6 +101,42 @@
       .getResult();
 }
 
+/// Returns field index of sparse tensor type for pointers/indices, when set.
+static unsigned getFieldIndex(Type type, unsigned ptrDim, unsigned idxDim) {
+  auto enc = getSparseTensorEncoding(type);
+  assert(enc);
+  RankedTensorType rType = type.cast<RankedTensorType>();
+  unsigned field = 2; // start past sizes
+  unsigned ptr = 0;
+  unsigned idx = 0;
+  for (unsigned r = 0, rank = rType.getShape().size(); r < rank; r++) {
+    switch (enc.getDimLevelType()[r]) {
+    case SparseTensorEncodingAttr::DimLevelType::Dense:
+      break; // no fields
+    case SparseTensorEncodingAttr::DimLevelType::Compressed:
+    case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
+    case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
+    case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
+      if (ptr++ == ptrDim)
+        return field;
+      field++;
+      if (idx++ == idxDim)
+        return field;
+      field++;
+      break;
+    case SparseTensorEncodingAttr::DimLevelType::Singleton:
+    case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
+    case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
+    case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
+      if (idx++ == idxDim)
+        return field;
+      field++;
+      break;
+    }
+  }
+  return field + 1; // return values field index
+}
+
 /// Maps a sparse tensor type to the appropriate compounded buffers.
 static Optional<LogicalResult>
 convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
@@ -118,10 +154,13 @@
   Type eltType = rType.getElementType();
   //
   // Sparse tensor storage for rank-dimensional tensor is organized as a
-  // single compound type with the following fields:
+  // single compound type with the following fields. Note that every
+  // memref with ? size actually behaves as a "vector", i.e. the stored
+  // size is the capacity and the used size resides in the memSizes array.
   //
   // struct {
   //   memref<rank x index> dimSizes  ; size in each dimension
+  //   memref<n x index> memSizes     ; sizes of ptrs/inds/values
   //   ; per-dimension d:
   //   ;  if dense:
   //
@@ -136,6 +175,9 @@
   unsigned rank = rType.getShape().size();
   // The dimSizes array.
   fields.push_back(MemRefType::get({rank}, indexType));
+  // The memSizes array.
+  unsigned lastField = getFieldIndex(type, -1, -1);
+  fields.push_back(MemRefType::get({lastField - 2}, indexType));
   // Per-dimension storage.
   for (unsigned r = 0; r < rank; r++) {
     // Dimension level types apply in order to the reordered dimension.
@@ -162,46 +204,10 @@
   }
   // The values array.
   fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, eltType));
+  assert(fields.size() == lastField);
   return success();
 }
 
-// Returns field index of sparse tensor type for pointers/indices, when set.
-static unsigned getFieldIndex(Type type, unsigned ptrDim, unsigned idxDim) {
-  auto enc = getSparseTensorEncoding(type);
-  assert(enc);
-  RankedTensorType rType = type.cast<RankedTensorType>();
-  unsigned field = 1; // start at DimSizes;
-  unsigned ptr = 0;
-  unsigned idx = 0;
-  for (unsigned r = 0, rank = rType.getShape().size(); r < rank; r++) {
-    switch (enc.getDimLevelType()[r]) {
-    case SparseTensorEncodingAttr::DimLevelType::Dense:
-      break; // no fields
-    case SparseTensorEncodingAttr::DimLevelType::Compressed:
-    case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
-    case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
-    case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
-      if (ptr++ == ptrDim)
-        return field;
-      field++;
-      if (idx++ == idxDim)
-        return field;
-      field++;
-      break;
-    case SparseTensorEncodingAttr::DimLevelType::Singleton:
-    case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
-    case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
-    case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
-      if (idx++ == idxDim)
-        return field;
-      field++;
-      break;
-    }
-  }
-  llvm_unreachable("failed to find ptr/idx field index");
-  return -1;
-}
-
 /// Create allocation operation.
 static Value createAllocation(OpBuilder &builder, Location loc, Type type,
                               Value sz) {
@@ -209,11 +215,12 @@
   return builder.create<memref::AllocOp>(loc, memType, sz);
 }
 
-/// Creates allocation for each field in sparse tensor type.
+/// Creates allocation for each field in sparse tensor type. Note that
+/// for all dynamic memrefs, the memory size is really the capacity of
+/// the "vector", while the actual size resides in the sizes array.
 ///
 /// TODO: for efficiency, we will need heuristis to make educated guesses
-///       on the required final sizes; also, we will need an improved
-///       memory allocation scheme with capacity and reallocation
+///       on the required capacities
 ///
 static void createAllocFields(OpBuilder &builder, Location loc, Type type,
                               ValueRange dynSizes,
@@ -246,6 +253,11 @@
   Value dimSizes = builder.create<memref::AllocOp>(
       loc, MemRefType::get({rank}, indexType));
   fields.push_back(dimSizes);
+  // The sizes array.
+  unsigned lastField = getFieldIndex(type, -1, -1);
+  Value memSizes = builder.create<memref::AllocOp>(
+      loc, MemRefType::get({lastField - 2}, indexType));
+  fields.push_back(memSizes);
   // Per-dimension storage.
   for (unsigned r = 0; r < rank; r++) {
     // Get the original dimension (ro) for the current stored dimension.
@@ -278,6 +290,16 @@
   // In all other case, we resort to the heuristical initial value.
   Value valuesSz = allDense ? linear : heuristic;
   fields.push_back(createAllocation(builder, loc, eltType, valuesSz));
+  // Reset memSizes to zero.
+  if (allDense)
+    builder.create<memref::StoreOp>(
+        loc, valuesSz, memSizes,
+        constantIndex(builder, loc, 0)); // TODO: avoid memSizes in this case?
+  else
+    builder.create<linalg::FillOp>(
+        loc, ValueRange{constantZero(builder, loc, indexType)},
+        ValueRange{memSizes});
+  assert(fields.size() == lastField);
 }
 
 //===----------------------------------------------------------------------===//
@@ -467,28 +489,6 @@
   }
 };
 
-/// Base class for getter-like operations, e.g., to_indices, to_pointers.
-template <typename Base, typename SourceOp>
-class SparseGetterOpConverter : public OpConversionPattern<SourceOp> {
-public:
-  using OpAdaptor = typename SourceOp::Adaptor;
-  using OpConversionPattern<SourceOp>::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    // Replace the requested pointer access with corresponding field.
-    // The cast_op is inserted by type converter to intermix 1:N type
-    // conversion.
-    auto tuple = llvm::cast<UnrealizedConversionCastOp>(
-        adaptor.getTensor().getDefiningOp());
-    unsigned idx = Base::getIndexForOp(tuple, op);
-    auto fields = tuple.getInputs();
-    assert(idx < fields.size());
-    rewriter.replaceOp(op, fields[idx]);
-    return success();
-  }
-};
-
 /// Sparse codegen rule for the expand op.
 class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
 public:
@@ -543,6 +543,28 @@
   }
 };
 
+/// Base class for getter-like operations, e.g., to_indices, to_pointers.
+template <typename Base, typename SourceOp>
+class SparseGetterOpConverter : public OpConversionPattern<SourceOp> {
+public:
+  using OpAdaptor = typename SourceOp::Adaptor;
+  using OpConversionPattern<SourceOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Replace the requested pointer access with corresponding field.
+    // The cast_op is inserted by type converter to intermix 1:N type
+    // conversion.
+    auto tuple = llvm::cast<UnrealizedConversionCastOp>(
+        adaptor.getTensor().getDefiningOp());
+    unsigned idx = Base::getIndexForOp(tuple, op);
+    auto fields = tuple.getInputs();
+    assert(idx < fields.size());
+    rewriter.replaceOp(op, fields[idx]);
+    return success();
+  }
+};
+
 /// Sparse codegen rule for pointer accesses.
 class SparseToPointersConverter
     : public SparseGetterOpConverter<SparseToPointersConverter, ToPointersOp> {
@@ -602,9 +624,9 @@
 void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
                                                RewritePatternSet &patterns) {
   patterns.add<SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
+               SparseCastConverter, SparseTensorAllocConverter,
+               SparseTensorDeallocConverter, SparseTensorLoadConverter,
+               SparseExpandConverter, SparseToPointersConverter,
+               SparseToIndicesConverter, SparseToValuesConverter>(
       typeConverter, patterns.getContext());
 }
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -42,24 +42,27 @@
 // CHECK-LABEL: func @sparse_nop(
 // CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<?xi32>,
-// CHECK-SAME: %[[A2:.*2]]: memref<?xi64>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xf64>)
-// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]] : memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>)
+// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]] : memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
 func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
   return %arg0 : tensor<?xf64, #SparseVector>
 }
 
 // CHECK-LABEL: func @sparse_nop_multi_ret(
 // CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<?xi32>,
-// CHECK-SAME: %[[A2:.*2]]: memref<?xi64>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xf64>,
-// CHECK-SAME: %[[A4:.*4]]: memref<1xindex>,
-// CHECK-SAME: %[[A5:.*5]]: memref<?xi32>,
-// CHECK-SAME: %[[A6:.*6]]: memref<?xi64>,
-// CHECK-SAME: %[[A7:.*7]]: memref<?xf64>) ->
-// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]]
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[A5:.*5]]: memref<1xindex>,
+// CHECK-SAME: %[[A6:.*6]]: memref<3xindex>,
+// CHECK-SAME: %[[A7:.*7]]: memref<?xi32>,
+// CHECK-SAME: %[[A8:.*8]]: memref<?xi64>,
+// CHECK-SAME: %[[A9:.*9]]: memref<?xf64>) ->
+// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[A9]]
 func.func @sparse_nop_multi_ret(%arg0: tensor<?xf64, #SparseVector>,
                                 %arg1: tensor<?xf64, #SparseVector>) ->
                                 (tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>) {
@@ -68,15 +71,17 @@
 // CHECK-LABEL: func @sparse_nop_call(
 // CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<?xi32>,
-// CHECK-SAME: %[[A2:.*2]]: memref<?xi64>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xf64>,
-// CHECK-SAME: %[[A4:.*4]]: memref<1xindex>,
-// CHECK-SAME: %[[A5:.*5]]: memref<?xi32>,
-// CHECK-SAME: %[[A6:.*6]]: memref<?xi64>,
-// CHECK-SAME: %[[A7:.*7]]: memref<?xf64>)
-// CHECK: %[[T0:.*]]:8 = call @sparse_nop_multi_ret(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]])
-// CHECK: return %[[T0]]#0, %[[T0]]#1, %[[T0]]#2, %[[T0]]#3, %[[T0]]#4, %[[T0]]#5, %[[T0]]#6, %[[T0]]#7
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[A5:.*5]]: memref<1xindex>,
+// CHECK-SAME: %[[A6:.*6]]: memref<3xindex>,
+// CHECK-SAME: %[[A7:.*7]]: memref<?xi32>,
+// CHECK-SAME: %[[A8:.*8]]: memref<?xi64>,
+// CHECK-SAME: %[[A9:.*9]]: memref<?xf64>)
+// CHECK: %[[T0:.*]]:10 = call @sparse_nop_multi_ret(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[A9]])
+// CHECK: return %[[T0]]#0, %[[T0]]#1, %[[T0]]#2, %[[T0]]#3, %[[T0]]#4, %[[T0]]#5, %[[T0]]#6, %[[T0]]#7, %[[T0]]#8, %[[T0]]#9
 func.func @sparse_nop_call(%arg0: tensor<?xf64, #SparseVector>,
                            %arg1: tensor<?xf64, #SparseVector>)
     -> (tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>) {
@@ -86,67 +91,67 @@
   return %1, %2: tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>
 }
 
-//
 // CHECK-LABEL: func @sparse_nop_cast(
 // CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<?xi32>,
-// CHECK-SAME: %[[A2:.*2]]: memref<?xi64>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xf32>)
-// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]] : memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf32>
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf32>)
+// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]] : memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf32>
 func.func @sparse_nop_cast(%arg0: tensor<64xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
   %0 = tensor.cast %arg0 : tensor<64xf32, #SparseVector> to tensor<?xf32, #SparseVector>
   return %0 : tensor<?xf32, #SparseVector>
 }
 
-//
 // CHECK-LABEL: func @sparse_nop_cast_3d(
 // CHECK-SAME: %[[A0:.*0]]: memref<3xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<?xf32>)
-// CHECK: return %[[A0]], %[[A1]] : memref<3xindex>, memref<?xf32>
+// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xf32>)
+// CHECK: return %[[A0]], %[[A1]], %[[A2]] : memref<3xindex>, memref<1xindex>, memref<?xf32>
 func.func @sparse_nop_cast_3d(%arg0: tensor<10x20x30xf32, #Dense3D>) -> tensor<?x?x?xf32, #Dense3D> {
   %0 = tensor.cast %arg0 : tensor<10x20x30xf32, #Dense3D> to tensor<?x?x?xf32, #Dense3D>
   return %0 : tensor<?x?x?xf32, #Dense3D>
 }
 
-//
 // CHECK-LABEL: func @sparse_dense_2d(
 // CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<?xf64>) {
+// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xf64>)
 // CHECK: return
 func.func @sparse_dense_2d(%arg0: tensor<?x?xf64, #Dense2D>) {
   return
 }
 
-//
 // CHECK-LABEL: func @sparse_row(
 // CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<?xi32>,
-// CHECK-SAME: %[[A2:.*2]]: memref<?xi64>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xf64>) {
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>)
 // CHECK: return
 func.func @sparse_row(%arg0: tensor<?x?xf64, #Row>) {
   return
 }
 
-//
 // CHECK-LABEL: func @sparse_csr(
 // CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<?xi32>,
-// CHECK-SAME: %[[A2:.*2]]: memref<?xi64>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xf64>) {
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>)
 // CHECK: return
 func.func @sparse_csr(%arg0: tensor<?x?xf64, #CSR>) {
   return
 }
 
-//
 // CHECK-LABEL: func @sparse_dcsr(
 // CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<?xi32>,
-// CHECK-SAME: %[[A2:.*2]]: memref<?xi64>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
-// CHECK-SAME: %[[A5:.*5]]: memref<?xf64>) {
+// CHECK-SAME: %[[A1:.*1]]: memref<5xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xi32>,
+// CHECK-SAME: %[[A5:.*5]]: memref<?xi64>,
+// CHECK-SAME: %[[A6:.*6]]: memref<?xf64>)
 // CHECK: return
 func.func @sparse_dcsr(%arg0: tensor<?x?xf64, #DCSR>) {
   return
@@ -156,10 +161,10 @@
 // Querying for dimension 1 in the tensor type can immediately
 // fold using the original static dimension sizes.
 //
-//
 // CHECK-LABEL: func @sparse_dense_3d(
 // CHECK-SAME: %[[A0:.*0]]: memref<3xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<?xf64>)
+// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xf64>)
 // CHECK: %[[C:.*]] = arith.constant 20 : index
 // CHECK: return %[[C]] : index
 func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index {
@@ -173,10 +178,10 @@
 // into querying for dimension 2 in the stored sparse tensor scheme,
 // since the latter honors the dimOrdering.
 //
-//
 // CHECK-LABEL: func @sparse_dense_3d_dyn(
 // CHECK-SAME: %[[A0:.*0]]: memref<3xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<?xf64>)
+// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xf64>)
 // CHECK: %[[C:.*]] = arith.constant 2 : index
 // CHECK: %[[L:.*]] = memref.load %[[A0]][%[[C]]] : memref<3xindex>
 // CHECK: return %[[L]] : index
@@ -186,115 +191,121 @@
   return %0 : index
 }
 
-//
 // CHECK-LABEL: func @sparse_pointers_dcsr(
 // CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<?xi32>,
-// CHECK-SAME: %[[A2:.*2]]: memref<?xi64>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
-// CHECK-SAME: %[[A5:.*5]]: memref<?xf64>)
-// CHECK: return %[[A3]] : memref<?xi32>
+// CHECK-SAME: %[[A1:.*1]]: memref<5xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xi32>,
+// CHECK-SAME: %[[A5:.*5]]: memref<?xi64>,
+// CHECK-SAME: %[[A6:.*6]]: memref<?xf64>)
+// CHECK: return %[[A4]] : memref<?xi32>
 func.func @sparse_pointers_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xi32> {
   %0 = sparse_tensor.pointers %arg0 { dimension = 1 : index } : tensor<?x?xf64, #DCSR> to memref<?xi32>
   return %0 : memref<?xi32>
 }
 
-//
 // CHECK-LABEL: func @sparse_indices_dcsr(
 // CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<?xi32>,
-// CHECK-SAME: %[[A2:.*2]]: memref<?xi64>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
-// CHECK-SAME: %[[A5:.*5]]: memref<?xf64>)
-// CHECK: return %[[A4]] : memref<?xi64>
+// CHECK-SAME: %[[A1:.*1]]: memref<5xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xi32>,
+// CHECK-SAME: %[[A5:.*5]]: memref<?xi64>,
+// CHECK-SAME: %[[A6:.*6]]: memref<?xf64>)
+// CHECK: return %[[A5]] : memref<?xi64>
 func.func @sparse_indices_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xi64> {
   %0 = sparse_tensor.indices %arg0 { dimension = 1 : index } : tensor<?x?xf64, #DCSR> to memref<?xi64>
   return %0 : memref<?xi64>
}
 
-//
 // CHECK-LABEL: func @sparse_values_dcsr(
 // CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<?xi32>,
-// CHECK-SAME: %[[A2:.*2]]: memref<?xi64>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
-// CHECK-SAME: %[[A5:.*5]]: memref<?xf64>)
-// CHECK: return %[[A5]] : memref<?xf64>
+// CHECK-SAME: %[[A1:.*1]]: memref<5xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xi32>,
+// CHECK-SAME: %[[A5:.*5]]: memref<?xi64>,
+// CHECK-SAME: %[[A6:.*6]]: memref<?xf64>)
+// CHECK: return %[[A6]] : memref<?xf64>
 func.func @sparse_values_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xf64> {
   %0 = sparse_tensor.values %arg0 : tensor<?x?xf64, #DCSR> to memref<?xf64>
   return %0 : memref<?xf64>
 }
 
-//
 // CHECK-LABEL: func @sparse_dealloc_csr(
 // CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<?xi32>,
-// CHECK-SAME: %[[A2:.*2]]: memref<?xi64>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xf64>) {
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>)
 // CHECK: memref.dealloc %[[A0]] : memref<2xindex>
-// CHECK: memref.dealloc %[[A1]] : memref<?xi32>
-// CHECK: memref.dealloc %[[A2]] : memref<?xi64>
-// CHECK: memref.dealloc %[[A3]] : memref<?xf64>
+// CHECK: memref.dealloc %[[A1]] : memref<3xindex>
+// CHECK: memref.dealloc %[[A2]] : memref<?xi32>
+// CHECK: memref.dealloc %[[A3]] : memref<?xi64>
+// CHECK: memref.dealloc %[[A4]] : memref<?xf64>
 // CHECK: return
 func.func @sparse_dealloc_csr(%arg0: tensor<?x?xf64, #CSR>) {
   bufferization.dealloc_tensor %arg0 : tensor<?x?xf64, #CSR>
   return
 }
 
-// CHECK-LABEL: func @sparse_alloc_csc(
-// CHECK-SAME: %[[A:.*]]: index) ->
-// CHECK-SAME: memref<2xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
-// CHECK: %[[T0:.*]] = memref.alloc() : memref<2xindex>
-// CHECK: memref.store %[[A]], %[[T0]][%[[C0]]] : memref<2xindex>
-// CHECK: memref.store %[[C10]], %[[T0]][%[[C1]]] : memref<2xindex>
-// CHECK: %[[T1:.*]] = memref.alloc() : memref<1xindex>
-// CHECK: %[[T2:.*]] = memref.cast %[[T1]] : memref<1xindex> to memref<?xindex>
-// CHECK: %[[T3:.*]] = memref.alloc() : memref<1xindex>
-// CHECK: %[[T4:.*]] = memref.cast %[[T3]] : memref<1xindex> to memref<?xindex>
-// CHECK: %[[T5:.*]] = memref.alloc() : memref<1xf64>
-// CHECK: %[[T6:.*]] = memref.cast %[[T5]] : memref<1xf64> to memref<?xf64>
-// CHECK: return %[[T0]], %[[T2]], %[[T4]], %[[T6]]
+// CHECK-LABEL: func @sparse_alloc_csc(
+// CHECK-SAME: %[[A:.*]]: index) ->
+// CHECK-SAME: memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
+// CHECK: %[[T0:.*]] = memref.alloc() : memref<2xindex>
+// CHECK: %[[T1:.*]] = memref.alloc() : memref<3xindex>
+// CHECK: memref.store %[[A]], %[[T0]][%[[C0]]] : memref<2xindex>
+// CHECK: memref.store %[[C10]], %[[T0]][%[[C1]]] : memref<2xindex>
+// CHECK: %[[T2:.*]] = memref.alloc() : memref<1xindex>
+// CHECK: %[[T3:.*]] = memref.cast %[[T2]] : memref<1xindex> to memref<?xindex>
+// CHECK: %[[T4:.*]] = memref.alloc() : memref<1xindex>
+// CHECK: %[[T5:.*]] = memref.cast %[[T4]] : memref<1xindex> to memref<?xindex>
+// CHECK: %[[T6:.*]] = memref.alloc() : memref<1xf64>
+// CHECK: %[[T7:.*]] = memref.cast %[[T6]] : memref<1xf64> to memref<?xf64>
+// CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T1]] : memref<3xindex>)
+// CHECK: return %[[T0]], %[[T1]], %[[T3]], %[[T5]], %[[T7]]
 func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> {
   %0 = bufferization.alloc_tensor(%arg0) : tensor<10x?xf64, #CSC>
   %1 = sparse_tensor.load %0 : tensor<10x?xf64, #CSC>
   return %1 : tensor<10x?xf64, #CSC>
 }
 
-// CHECK-LABEL: func @sparse_alloc_3d() ->
-// CHECK-SAME: memref<3xindex>, memref<?xf64>
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
-// CHECK-DAG: %[[C20:.*]] = arith.constant 20 : index
-// CHECK-DAG: %[[C30:.*]] = arith.constant 30 : index
-// CHECK: %[[A0:.*]] = memref.alloc() : memref<3xindex>
-// CHECK: memref.store %[[C30]], %[[A0]][%[[C0]]] : memref<3xindex>
-// CHECK: memref.store %[[C10]], %[[A0]][%[[C1]]] : memref<3xindex>
-// CHECK: memref.store %[[C20]], %[[A0]][%[[C2]]] : memref<3xindex>
-// CHECK: %[[A:.*]] = memref.alloc() : memref<6000xf64>
-// CHECK: %[[A1:.*]] = memref.cast %[[A]] : memref<6000xf64> to memref<?xf64>
-// CHECK: return %[[A0]], %[[A1]] : memref<3xindex>, memref<?xf64>
+// CHECK-LABEL: func @sparse_alloc_3d() ->
+// CHECK-SAME: memref<3xindex>, memref<1xindex>, memref<?xf64>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.*]] = arith.constant 20 : index
+// CHECK-DAG: %[[C30:.*]] = arith.constant 30 : index
+// CHECK-DAG: %[[C6000:.*]] = arith.constant 6000 : index
+// CHECK: %[[A0:.*]] = memref.alloc() : memref<3xindex>
+// CHECK: %[[A1:.*]] = memref.alloc() : memref<1xindex>
+// CHECK: memref.store %[[C30]], %[[A0]][%[[C0]]] : memref<3xindex>
+// CHECK: memref.store %[[C10]], %[[A0]][%[[C1]]] : memref<3xindex>
+// CHECK: memref.store %[[C20]], %[[A0]][%[[C2]]] : memref<3xindex>
+// CHECK: %[[A:.*]] = memref.alloc() : memref<6000xf64>
+// CHECK: %[[A2:.*]] = memref.cast %[[A]] : memref<6000xf64> to memref<?xf64>
+// CHECK: memref.store %[[C6000]], %[[A1]][%[[C0]]] : memref<1xindex>
+// CHECK: return %[[A0]], %[[A1]], %[[A2]] : memref<3xindex>, memref<1xindex>, memref<?xf64>
 func.func @sparse_alloc_3d() -> tensor<10x20x30xf64, #Dense3D> {
   %0 = bufferization.alloc_tensor() : tensor<10x20x30xf64, #Dense3D>
   %1 = sparse_tensor.load %0 : tensor<10x20x30xf64, #Dense3D>
   return %1 : tensor<10x20x30xf64, #Dense3D>
 }
 
-// CHECK-LABEL: func.func @sparse_expansion1()
-// CHECK: %[[A:.*]] = memref.alloc() : memref<8xf64>
-// CHECK: %[[B:.*]] = memref.alloc() : memref<8xi1>
-// CHECK: %[[C:.*]] = memref.alloc() : memref<8xindex>
-// CHECK: %[[D:.*]] = memref.cast %[[C]] : memref<8xindex> to memref<?xindex>
-// CHECK-DAG: linalg.fill ins(%{{.*}} : f64) outs(%[[A]] : memref<8xf64>)
-// CHECK-DAG: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref<8xi1>)
-// CHECK: return %[[D]] : memref<?xindex>
+// CHECK-LABEL: func.func @sparse_expansion1()
+// CHECK: %[[A:.*]] = memref.alloc() : memref<8xf64>
+// CHECK: %[[B:.*]] = memref.alloc() : memref<8xi1>
+// CHECK: %[[C:.*]] = memref.alloc() : memref<8xindex>
+// CHECK: %[[D:.*]] = memref.cast %[[C]] : memref<8xindex> to memref<?xindex>
+// CHECK-DAG: linalg.fill ins(%{{.*}} : f64) outs(%[[A]] : memref<8xf64>)
+// CHECK-DAG: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref<8xi1>)
+// CHECK: return %[[D]] : memref<?xindex>
 func.func @sparse_expansion1() -> memref<?xindex> {
   %0 = bufferization.alloc_tensor() : tensor<4x8xf64, #CSR>
   %values, %filled, %added, %count = sparse_tensor.expand %0
@@ -302,14 +313,14 @@
   return %added : memref<?xindex>
 }
 
-// CHECK-LABEL: func.func @sparse_expansion2()
-// CHECK: %[[A:.*]] = memref.alloc() : memref<4xf64>
-// CHECK: %[[B:.*]] = memref.alloc() : memref<4xi1>
-// CHECK: %[[C:.*]] = memref.alloc() : memref<4xindex>
-// CHECK: %[[D:.*]] = memref.cast %[[C]] : memref<4xindex> to memref<?xindex>
-// CHECK-DAG: linalg.fill ins(%{{.*}} : f64) outs(%[[A]] : memref<4xf64>)
-// CHECK-DAG: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref<4xi1>)
-// CHECK: return %[[D]] : memref<?xindex>
+// CHECK-LABEL: func.func @sparse_expansion2()
+// CHECK: %[[A:.*]] = memref.alloc() : memref<4xf64>
+// CHECK: %[[B:.*]] = memref.alloc() : memref<4xi1>
+// CHECK: %[[C:.*]] = memref.alloc() : memref<4xindex>
+// CHECK: %[[D:.*]] = memref.cast %[[C]] : memref<4xindex> to memref<?xindex>
+// CHECK-DAG: linalg.fill ins(%{{.*}} : f64) outs(%[[A]] : memref<4xf64>)
+// CHECK-DAG: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref<4xi1>)
+// CHECK: return %[[D]] : memref<?xindex>
 func.func @sparse_expansion2() -> memref<?xindex> {
   %0 = bufferization.alloc_tensor() : tensor<4x8xf64, #CSC>
   %values, %filled, %added, %count = sparse_tensor.expand %0
@@ -317,19 +328,19 @@
   return %added : memref<?xindex>
 }
 
-// CHECK-LABEL: func.func @sparse_expansion3(
-// CHECK-SAME: %[[D0:.*]]: index,
-// CHECK-SAME: %{{.*}}: index) -> memref<?xindex> {
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[S0:.*]] = memref.alloc() : memref<2xindex>
-// CHECK: memref.store %[[D0]], %[[S0]]{{\[}}%[[C1]]] : memref<2xindex>
-// CHECK: %[[D1:.*]] = memref.load %[[S0]]{{\[}}%[[C1]]] : memref<2xindex>
-// CHECK: %[[V:.*]] = memref.alloc(%[[D1]]) : memref<?xf64>
-// CHECK: %[[B:.*]] = memref.alloc(%[[D1]]) : memref<?xi1>
-// CHECK: %[[D:.*]] = memref.alloc(%[[D1]]) : memref<?xindex>
-// CHECK: linalg.fill ins(%{{.*}} : f64) outs(%[[V]] : memref<?xf64>)
-// CHECK: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref<?xi1>)
-// CHECK: return %[[D]] : memref<?xindex>
+// CHECK-LABEL: func.func @sparse_expansion3(
+// CHECK-SAME: %[[D0:.*]]: index,
+// CHECK-SAME: %{{.*}}: index) -> memref<?xindex> {
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[S0:.*]] = memref.alloc() : memref<2xindex>
+// CHECK: memref.store %[[D0]], %[[S0]]{{\[}}%[[C1]]] : memref<2xindex>
+// CHECK: %[[D1:.*]] = memref.load %[[S0]]{{\[}}%[[C1]]] : memref<2xindex>
+// CHECK: %[[V:.*]] = memref.alloc(%[[D1]]) : memref<?xf64>
+// CHECK: %[[B:.*]] = memref.alloc(%[[D1]]) : memref<?xi1>
+// CHECK: %[[D:.*]] = memref.alloc(%[[D1]]) : memref<?xindex>
+// CHECK: linalg.fill ins(%{{.*}} : f64) outs(%[[V]] : memref<?xf64>)
+// CHECK: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref<?xi1>)
+// CHECK: return %[[D]] : memref<?xindex>
 func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
   %0 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64, #CSC>
   %values, %filled, %added, %count = sparse_tensor.expand %0