diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorBase.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorBase.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorBase.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorBase.td @@ -49,9 +49,9 @@ lattices drive actual sparse code generation, which consists of a relatively straightforward one-to-one mapping from iteration lattices to combinations of for-loops, while-loops, and if-statements. Sparse - tensor outputs that materialize uninitialized are handled with - insertions in pure lexicographical index order if all parallel loops - are outermost or using a 1-dimensional access pattern expansion + tensor outputs that materialize uninitialized are handled with direct + insertions if all parallel loops are outermost or insertions that + indirectly go through a 1-dimensional access pattern expansion (a.k.a. workspace) where feasible [Gustavson72,Bik96,Kjolstad19]. * [Bik96] Aart J.C. Bik. Compiler Support for Sparse Matrix Computations. diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -197,18 +197,25 @@ // as our sparse abstractions evolve. //===----------------------------------------------------------------------===// -def SparseTensor_LexInsertOp : SparseTensor_Op<"lex_insert", []>, +def SparseTensor_InsertOp : SparseTensor_Op<"insert", []>, Arguments<(ins AnySparseTensor:$tensor, StridedMemRefRankOf<[Index], [1]>:$indices, AnyType:$value)> { - string summary = "Inserts a value into given sparse tensor in lexicographical index order"; + string summary = "Inserts a value into given sparse tensor"; string description = [{ Inserts the given value at given indices into the underlying sparse storage format of the given tensor with the given indices. 
This operation can only be applied when a tensor materializes unintialized - with a `bufferization.alloc_tensor` operation, the insertions occur in - strict lexicographical index order, and the final tensor is constructed - with a `load` operation that has the `hasInserts` attribute set. + with a `bufferization.alloc_tensor` operation and the final tensor + is constructed with a `load` operation that has the `hasInserts` + attribute set. + + Properties in the sparse tensor type fully describe what kind + of insertion order is allowed. When all dimensions have "unique" + and "ordered" properties, for example, insertions should occur in + strict lexicographical index order. Other properties define + different insertion regimes. Inserting in a way contrary to + these properties results in undefined behavior. Note that this operation is "impure" in the sense that its behavior is solely defined by side-effects and not SSA values. The semantics @@ -217,7 +224,7 @@ Example: ```mlir - sparse_tensor.lex_insert %tensor, %indices, %val + sparse_tensor.insert %tensor, %indices, %val : tensor<1024x1024xf64, #CSR>, memref, memref ``` }]; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -1137,13 +1137,15 @@ } }; -/// Sparse conversion rule for inserting in lexicographic index order. -class SparseTensorLexInsertConverter : public OpConversionPattern { +/// Sparse conversion rule for the insertion operator.
+class SparseTensorInsertConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(LexInsertOp op, OpAdaptor adaptor, + matchAndRewrite(InsertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + // Note that the current regime only allows for strict lexicographic + // index order. Type elemTp = op.getTensor().getType().cast().getElementType(); SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)}; replaceOpWithFuncCall(rewriter, op, name, {}, adaptor.getOperands(), @@ -1432,7 +1434,7 @@ SparseTensorConcatConverter, SparseTensorAllocConverter, SparseTensorDeallocConverter, SparseTensorToPointersConverter, SparseTensorToIndicesConverter, SparseTensorToValuesConverter, - SparseTensorLoadConverter, SparseTensorLexInsertConverter, + SparseTensorLoadConverter, SparseTensorInsertConverter, SparseTensorExpandConverter, SparseTensorCompressConverter, SparseTensorOutConverter>(typeConverter, patterns.getContext()); diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir --- a/mlir/test/Dialect/SparseTensor/conversion.mlir +++ b/mlir/test/Dialect/SparseTensor/conversion.mlir @@ -494,7 +494,7 @@ func.func @sparse_insert(%arg0: tensor<128xf32, #SparseVector>, %arg1: memref, %arg2: memref) { - sparse_tensor.lex_insert %arg0, %arg1, %arg2 : tensor<128xf32, #SparseVector>, memref, memref + sparse_tensor.insert %arg0, %arg1, %arg2 : tensor<128xf32, #SparseVector>, memref, memref return } diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir --- a/mlir/test/Dialect/SparseTensor/invalid.mlir +++ b/mlir/test/Dialect/SparseTensor/invalid.mlir @@ -107,8 +107,8 @@ // ----- func.func @sparse_unannotated_insert(%arg0: tensor<128xf64>, %arg1: memref, %arg2: f64) { - // expected-error@+1 {{'sparse_tensor.lex_insert' op operand #0 must be sparse tensor of any type values, but got 
'tensor<128xf64>'}} - sparse_tensor.lex_insert %arg0, %arg1, %arg2 : tensor<128xf64>, memref, f64 + // expected-error@+1 {{'sparse_tensor.insert' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}} + sparse_tensor.insert %arg0, %arg1, %arg2 : tensor<128xf64>, memref, f64 return } diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir --- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir +++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir @@ -123,10 +123,10 @@ // CHECK-SAME: %[[A:.*]]: tensor<128xf64, #sparse_tensor.encoding<{{.*}}>>, // CHECK-SAME: %[[B:.*]]: memref, // CHECK-SAME: %[[C:.*]]: f64) { -// CHECK: sparse_tensor.lex_insert %[[A]], %[[B]], %[[C]] : tensor<128xf64, #{{.*}}>, memref, f64 +// CHECK: sparse_tensor.insert %[[A]], %[[B]], %[[C]] : tensor<128xf64, #{{.*}}>, memref, f64 // CHECK: return func.func @sparse_insert(%arg0: tensor<128xf64, #SparseVector>, %arg1: memref, %arg2: f64) { - sparse_tensor.lex_insert %arg0, %arg1, %arg2 : tensor<128xf64, #SparseVector>, memref, f64 + sparse_tensor.insert %arg0, %arg1, %arg2 : tensor<128xf64, #SparseVector>, memref, f64 return } diff --git a/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir b/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir @@ -375,7 +375,7 @@ // CHECK: %[[VAL_20:.*]] = math.sin %[[VAL_19]] : f64 // CHECK: %[[VAL_21:.*]] = math.tanh %[[VAL_20]] : f64 // CHECK: memref.store %[[VAL_21]], %[[BUF]][] : memref -// CHECK: sparse_tensor.lex_insert %[[VAL_4]], %[[VAL_8]], %[[BUF]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>, memref, memref +// CHECK: sparse_tensor.insert %[[VAL_4]], %[[VAL_8]], %[[BUF]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>, memref, memref // CHECK: } // CHECK: %[[VAL_22:.*]] = sparse_tensor.load %[[VAL_4]] hasInserts : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> 
// CHECK: return %[[VAL_22]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> @@ -419,7 +419,7 @@ // CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_12]]] : memref> // CHECK: %[[VAL_15:.*]] = complex.div %[[VAL_14]], %[[VAL_3]] : complex // CHECK: memref.store %[[VAL_15]], %[[VAL_9]][] : memref> -// CHECK: sparse_tensor.lex_insert %[[VAL_4]], %[[VAL_8]], %[[VAL_9]] : tensor<32xcomplex, #sparse_tensor.encoding<{{.*}}>>, memref, memref> +// CHECK: sparse_tensor.insert %[[VAL_4]], %[[VAL_8]], %[[VAL_9]] : tensor<32xcomplex, #sparse_tensor.encoding<{{.*}}>>, memref, memref> // CHECK: } // CHECK: %[[VAL_16:.*]] = sparse_tensor.load %[[VAL_4]] hasInserts : tensor<32xcomplex, #sparse_tensor.encoding<{{.*}}>> // CHECK: return %[[VAL_16]] : tensor<32xcomplex, #sparse_tensor.encoding<{{.*}}>> diff --git a/mlir/test/Dialect/SparseTensor/sparse_index.mlir b/mlir/test/Dialect/SparseTensor/sparse_index.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_index.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_index.mlir @@ -100,7 +100,7 @@ // CHECK: %[[VAL_25:.*]] = arith.muli %[[VAL_23]], %[[VAL_24]] : i64 // CHECK: %[[VAL_26:.*]] = arith.muli %[[VAL_22]], %[[VAL_25]] : i64 // CHECK: memref.store %[[VAL_26]], %[[BUF]][] : memref -// CHECK: sparse_tensor.lex_insert %[[VAL_6]], %[[VAL_12]], %[[BUF]] : tensor // CHECK: %[[VAL_19:.*]] = arith.mulf %[[VAL_18]], %[[VAL_1]] : f32 // CHECK: memref.store %[[VAL_19]], %[[BUF]][] : memref -// CHECK: sparse_tensor.lex_insert %[[VAL_7]], %[[VAL_11]], %[[BUF]] : tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>> +// CHECK: sparse_tensor.insert %[[VAL_7]], %[[VAL_11]], %[[BUF]] : tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>> // CHECK: } // CHECK: } // CHECK: %[[VAL_20:.*]] = sparse_tensor.load %[[VAL_7]] hasInserts : tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>> @@ -259,7 +259,7 @@ // CHECK: scf.yield %[[VAL_94]], %[[VAL_97]], %[[VAL_98:.*]] : index, index, i32 // CHECK: } // CHECK: memref.store %[[VAL_70]]#2, %[[BUF]][] : 
memref -// CHECK: sparse_tensor.lex_insert %[[VAL_8]], %[[VAL_23]], %[[BUF]] : tensor, memref, memref +// CHECK: sparse_tensor.insert %[[VAL_8]], %[[VAL_23]], %[[BUF]] : tensor, memref, memref // CHECK: } else { // CHECK: } // CHECK: %[[VAL_100:.*]] = arith.cmpi eq, %[[VAL_57]], %[[VAL_60]] : index diff --git a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir @@ -166,7 +166,7 @@ // CHECK: scf.yield %[[VAL_37]] : f64 // CHECK: } // CHECK: memref.store %[[VAL_30:.*]], %[[VAL_19]][] : memref -// CHECK: sparse_tensor.lex_insert %[[VAL_10]], %[[VAL_18]], %[[VAL_19]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>, memref, memref +// CHECK: sparse_tensor.insert %[[VAL_10]], %[[VAL_18]], %[[VAL_19]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>, memref, memref // CHECK: } // CHECK: } // CHECK: %[[VAL_39:.*]] = sparse_tensor.load %[[VAL_10]] hasInserts : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> diff --git a/mlir/test/Dialect/SparseTensor/sparse_transpose.mlir b/mlir/test/Dialect/SparseTensor/sparse_transpose.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_transpose.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_transpose.mlir @@ -42,7 +42,7 @@ // CHECK: memref.store %[[VAL_21]], %[[VAL_11]]{{\[}}%[[VAL_2]]] : memref // CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_20]]] : memref // CHECK: memref.store %[[VAL_22]], %[[VAL_12]][] : memref -// CHECK: sparse_tensor.lex_insert %[[VAL_4]], %[[VAL_11]], %[[VAL_12]] : tensor<4x3xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>, memref, memref +// CHECK: sparse_tensor.insert %[[VAL_4]], %[[VAL_11]], %[[VAL_12]] : tensor<4x3xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>, memref, memref // CHECK: } // CHECK: } // CHECK: %[[VAL_23:.*]] = sparse_tensor.load %[[VAL_4]] hasInserts 
: tensor<4x3xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>