diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -190,20 +190,21 @@ //===----------------------------------------------------------------------===// // Sparse Tensor Management Operations. These operations are "impure" in the -// sense that they do not properly operate on SSA values. Instead, the behavior -// is solely defined by side-effects. These operations provide a bridge between -// "sparsification" on one hand and a support library or actual code generation -// on the other hand. The semantics of these operations may be refined over time -// as our sparse abstractions evolve. +// sense that some behavior is defined by side-effects. These operations provide +// a bridge between "sparsification" on one hand and a support library or actual +// code generation on the other hand. The semantics of these operations may be +// refined over time as our sparse abstractions evolve. //===----------------------------------------------------------------------===// def SparseTensor_InsertOp : SparseTensor_Op<"insert", [TypesMatchWith<"value type matches element type of tensor", "tensor", "value", - "$_self.cast().getElementType()">]>, + "$_self.cast().getElementType()">, + AllTypesMatch<["tensor", "result"]>]>, Arguments<(ins AnyType:$value, AnySparseTensor:$tensor, - Variadic:$indices)> { + Variadic:$indices)>, + Results<(outs AnySparseTensor:$result)> { string summary = "Inserts a value into given sparse tensor"; string description = [{ Inserts the given value at given indices into the underlying @@ -221,19 +222,19 @@ different insertion regimens. Inserting in a way contrary to these properties results in undefined behavior. - Note that this operation is "impure" in the sense that its behavior - is solely defined by side-effects and not SSA values. 
The semantics - may be refined over time as our sparse abstractions evolve. In - particular, this operation is scheduled to be unified with the - dense counterpart `tensor.insert` that has proper SSA semantics. + Note that this operation is "impure" in the sense that even though + the result is modeled through an SSA value, the insertion is eventually + done "in place", and referencing the old SSA value is undefined behavior. + This operation is scheduled to be unified with the dense counterpart + `tensor.insert` that has pure SSA semantics. Example: ```mlir - sparse_tensor.insert %val into %tensor[%i,%j] : tensor<1024x1024xf64, #CSR> + %result = sparse_tensor.insert %val into %tensor[%i,%j] : tensor<1024x1024xf64, #CSR> ``` }]; - let assemblyFormat = "$value `into` $tensor `[` $indices `]` attr-dict`:` type($tensor)"; + let assemblyFormat = "$value `into` $tensor `[` $indices `]` attr-dict `:` type($tensor)"; let hasVerifier = 1; } @@ -255,7 +256,8 @@ the code for capacity check and reallocation. The typical usage will be for "dynamic" sparse tensors for which a capacity can be set beforehand. - The operation returns an SSA value for the memref. Referencing the memref + Note that this operation is "impure" in the sense that even though + the result is modeled through an SSA value, referencing the memref through the old SSA value after this operation is undefined behavior. Example: @@ -302,9 +304,9 @@ through an indirection using the added array, so that the operations are kept proportional to the number of nonzeros. - Note that this operation is "impure" in the sense that its behavior - is solely defined by side-effects and not SSA values. The semantics - may be refined over time as our sparse abstractions evolve. + Note that this operation is "impure" in the sense that even though the + results are modeled through SSA values, the operation relies on a proper + side-effecting context that sets and resets the expanded arrays. 
Example: @@ -317,13 +319,15 @@ " `,` type($filled) `,` type($added)"; } -def SparseTensor_CompressOp : SparseTensor_Op<"compress", []>, +def SparseTensor_CompressOp : SparseTensor_Op<"compress", + [AllTypesMatch<["tensor", "result"]>]>, Arguments<(ins AnyStridedMemRefOfRank<1>:$values, StridedMemRefRankOf<[I1],[1]>:$filled, StridedMemRefRankOf<[Index],[1]>:$added, Index:$count, AnySparseTensor:$tensor, - Variadic:$indices)> { + Variadic:$indices)>, + Results<(outs AnySparseTensor:$result)> { string summary = "Compressed an access pattern for insertion"; string description = [{ Finishes a single access pattern expansion by moving inserted elements @@ -335,14 +339,14 @@ array, so that the operations are kept proportional to the number of nonzeros. See the `sparse_tensor.expand` operation for more details. - Note that this operation is "impure" in the sense that its behavior - is solely defined by side-effects and not SSA values. The semantics - may be refined over time as our sparse abstractions evolve. + Note that this operation is "impure" in the sense that even though + the result is modeled through an SSA value, the insertion is eventually + done "in place", and referencing the old SSA value is undefined behavior. Example: ```mlir - sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i] + %result = sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i] : memref, memref, memref, tensor<4x4xf64, #CSR> ``` }]; @@ -372,14 +376,16 @@ sparse storage format needs to be finalized. Otherwise, the operation simply folds away. - Note that this operation is "impure" in the sense that its behavior - is solely defined by side-effects and not SSA values. The semantics - may be refined over time as our sparse abstractions evolve. + Note that this operation is "impure" in the sense that even though + the result is modeled through an SSA value, the operation relies on + a proper context of materializing and inserting the tensor value. 
- Example: + Examples: ```mlir - %1 = sparse_tensor.load %0 : tensor<8xf64, #SV> + %result = sparse_tensor.load %tensor : tensor<8xf64, #SV> + + %1 = sparse_tensor.load %0 hasInserts : tensor<16x32xf32, #CSR> ``` }]; let assemblyFormat = "$tensor (`hasInserts` $hasInserts^)? attr-dict `:` type($tensor)"; @@ -397,8 +403,7 @@ a buffer defined by a pointer. Note that this operation is "impure" in the sense that its behavior - is solely defined by side-effects and not SSA values. The semantics - may be refined over time as our sparse abstractions evolve. + is solely defined by side-effects and not SSA values. Example: @@ -442,8 +447,7 @@ be used to implement the operator. Note that this operation is "impure" in the sense that its behavior is - solely defined by side-effects and not SSA values. The semantics may be - refined over time as our sparse abstractions evolve. + solely defined by side-effects and not SSA values. Example: diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -35,6 +35,18 @@ // Helper methods. //===----------------------------------------------------------------------===// +/// Returns the "tuple" value of the adapted tensor. +static UnrealizedConversionCastOp getTuple(Value tensor) { + return llvm::cast(tensor.getDefiningOp()); +} + +/// Packs the given values as a "tuple" value. +static Value genTuple(OpBuilder &rewriter, Location loc, Type tp, + ValueRange values) { + return rewriter.create(loc, TypeRange(tp), values) + .getResult(0); +} + /// Flatten a list of operands that may contain sparse tensors. static void flattenOperands(ValueRange operands, SmallVectorImpl &flattened) { @@ -43,14 +55,13 @@ // ==> // memref ..., c, memref ... 
for (auto operand : operands) { - if (auto cast = - dyn_cast<UnrealizedConversionCastOp>(operand.getDefiningOp()); - cast && getSparseTensorEncoding(cast->getResultTypes()[0])) + if (auto tuple = operand.getDefiningOp<UnrealizedConversionCastOp>(); + tuple && getSparseTensorEncoding(tuple->getResultTypes()[0])) // An unrealized_conversion_cast will be inserted by type converter to // inter-mix the gap between 1:N conversion between sparse tensors and // fields. In this case, take the operands in the cast and replace the // sparse tensor output with the flattened type array. - flattened.append(cast.getOperands().begin(), cast.getOperands().end()); + flattened.append(tuple.getOperands().begin(), tuple.getOperands().end()); else flattened.push_back(operand); } @@ -73,8 +84,7 @@ // Any other query can consult the dimSizes array at field 0 using, // accounting for the reordering applied to the sparse storage. - auto tuple = - llvm::cast<UnrealizedConversionCastOp>(adaptedValue.getDefiningOp()); + auto tuple = getTuple(adaptedValue); Value idx = constantIndex(rewriter, loc, toStoredDim(tensorTp, dim)); return rewriter.create<memref::LoadOp>(loc, tuple.getInputs().front(), idx) .getResult(); @@ -264,6 +274,54 @@ return forOp; } +/// Creates a pushback op for given field and updates the fields array +/// accordingly. +static void createPushback(OpBuilder &builder, Location loc, + SmallVectorImpl<Value> &fields, unsigned field, + Value value) { + assert(field < fields.size()); + fields[field] = + builder.create<PushBackOp>(loc, fields[field].getType(), fields[1], + fields[field], value, APInt(64, field)); +} + +/// Generates insertion code. +// +// TODO: generalize this for any rank and format; currently it is just sparse +// vectors as a proof of concept that we have everything in place!
+// +static void genInsert(OpBuilder &builder, Location loc, RankedTensorType rtp, + SmallVectorImpl<Value> &fields, + SmallVectorImpl<Value> &indices, Value value) { + unsigned rank = indices.size(); + assert(rtp.getShape().size() == rank); + if (rank != 1 || !isCompressedDim(rtp, 0) || !isUniqueDim(rtp, 0) || + !isOrderedDim(rtp, 0)) + return; // TODO: add codegen + // push_back memSizes pointers-0 0 + // push_back memSizes indices-0 index + // push_back memSizes values value + Value zero = constantIndex(builder, loc, 0); + createPushback(builder, loc, fields, 2, zero); + createPushback(builder, loc, fields, 3, indices[0]); + createPushback(builder, loc, fields, 4, value); +} + +/// Generates insertion finalization code. +// +// TODO: this too only works for the very simple case +// +static void genEndInsert(OpBuilder &builder, Location loc, RankedTensorType rtp, + SmallVectorImpl<Value> &fields) { + if (rtp.getShape().size() != 1 || !isCompressedDim(rtp, 0) || + !isUniqueDim(rtp, 0) || !isOrderedDim(rtp, 0)) + return; // TODO: add codegen + // push_back memSizes pointers-0 memSizes[2] + Value two = constantIndex(builder, loc, 2); + Value size = builder.create<memref::LoadOp>(loc, fields[1], two); + createPushback(builder, loc, fields, 2, size); +} + //===----------------------------------------------------------------------===// // Codegen rules. //===----------------------------------------------------------------------===// @@ -325,12 +383,10 @@ assert(!sparseFlat.empty()); if (sparseFlat.size() > 1) { auto flatSize = sparseFlat.size(); - ValueRange sparseElem(iterator_range<ResultRange::iterator>( + ValueRange fields(iterator_range<ResultRange::iterator>( newCall.result_begin() + retOffset, newCall.result_begin() + retOffset + flatSize)); - auto castOp = rewriter.create<UnrealizedConversionCastOp>( - loc, TypeRange({retType}), sparseElem); - castedRet.push_back(castOp.getResult(0)); + castedRet.push_back(genTuple(rewriter, loc, retType, fields)); retOffset += flatSize; } else { // If this is an 1:1 conversion, no need for casting.
@@ -404,8 +460,7 @@ Location loc = op.getLoc(); SmallVector fields; createAllocFields(rewriter, loc, resType, adaptor.getOperands(), fields); - rewriter.replaceOpWithNewOp( - op, TypeRange{resType}, fields); + rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields)); return success(); } }; @@ -424,8 +479,7 @@ // Replace the sparse tensor deallocation with field deallocations. Location loc = op.getLoc(); - auto tuple = llvm::cast( - adaptor.getTensor().getDefiningOp()); + auto tuple = getTuple(adaptor.getTensor()); for (auto input : tuple.getInputs()) // Deallocate every buffer used to store the sparse tensor handler. rewriter.create(loc, input); @@ -442,11 +496,15 @@ LogicalResult matchAndRewrite(LoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (op.getHasInserts()) { - // Finalize any pending insertions. - // TODO: implement - } - rewriter.replaceOp(op, adaptor.getOperands()); + RankedTensorType srcType = + op.getTensor().getType().cast(); + auto tuple = getTuple(adaptor.getTensor()); + // Prepare fields. + SmallVector fields(tuple.getInputs()); + // Generate optional insertion finalization code. + if (op.getHasInserts()) + genEndInsert(rewriter, op.getLoc(), srcType, fields); + rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), srcType, fields)); return success(); } }; @@ -514,10 +572,14 @@ RankedTensorType dstType = op.getTensor().getType().cast(); Type eltType = dstType.getElementType(); + auto tuple = getTuple(adaptor.getTensor()); Value values = adaptor.getValues(); Value filled = adaptor.getFilled(); Value added = adaptor.getAdded(); Value count = adaptor.getCount(); + // Prepare fields and indices. + SmallVector fields(tuple.getInputs()); + SmallVector indices(adaptor.getIndices()); // If the innermost dimension is ordered, we need to sort the indices // in the "added" array prior to applying the compression. 
unsigned rank = dstType.getShape().size(); @@ -532,21 +594,20 @@ // for (i = 0; i < count; i++) { // index = added[i]; // value = values[index]; - // - // TODO: insert prev_indices, index, value - // + // insert({prev_indices, index}, value); // values[index] = 0; // filled[index] = false; // } Value i = createFor(rewriter, loc, count).getInductionVar(); Value index = rewriter.create(loc, added, i); - rewriter.create(loc, values, index); - // TODO: insert + Value value = rewriter.create(loc, values, index); + indices.push_back(index); + // TODO: generate yield cycle + genInsert(rewriter, loc, dstType, fields, indices, value); rewriter.create(loc, constantZero(rewriter, loc, eltType), values, index); rewriter.create(loc, constantI1(rewriter, loc, false), filled, index); - // Deallocate the buffers on exit of the full loop nest. Operation *parent = op; for (; isa(parent->getParentOp()) || @@ -559,7 +620,28 @@ rewriter.create(loc, values); rewriter.create(loc, filled); rewriter.create(loc, added); - rewriter.eraseOp(op); + rewriter.replaceOp(op, genTuple(rewriter, loc, dstType, fields)); + return success(); + } +}; + +/// Sparse codegen rule for the insert operator. +class SparseInsertConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(InsertOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + RankedTensorType dstType = + op.getTensor().getType().cast(); + auto tuple = getTuple(adaptor.getTensor()); + // Prepare fields and indices. + SmallVector fields(tuple.getInputs()); + SmallVector indices(adaptor.getIndices()); + // Generate insertion. + Value value = adaptor.getValue(); + genInsert(rewriter, op->getLoc(), dstType, fields, indices, value); + rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), dstType, fields)); return success(); } }; @@ -576,8 +658,7 @@ // Replace the requested pointer access with corresponding field. 
// The cast_op is inserted by type converter to intermix 1:N type // conversion. - auto tuple = llvm::cast( - adaptor.getTensor().getDefiningOp()); + auto tuple = getTuple(adaptor.getTensor()); unsigned idx = Base::getIndexForOp(tuple, op); auto fields = tuple.getInputs(); assert(idx < fields.size()); @@ -648,6 +729,7 @@ SparseCastConverter, SparseTensorAllocConverter, SparseTensorDeallocConverter, SparseTensorLoadConverter, SparseExpandConverter, SparseCompressConverter, - SparseToPointersConverter, SparseToIndicesConverter, - SparseToValuesConverter>(typeConverter, patterns.getContext()); + SparseInsertConverter, SparseToPointersConverter, + SparseToIndicesConverter, SparseToValuesConverter>( + typeConverter, patterns.getContext()); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -1071,9 +1071,9 @@ constantIndex(rewriter, loc, i)); rewriter.create(loc, adaptor.getValue(), vref); SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)}; - replaceOpWithFuncCall(rewriter, op, name, {}, - {adaptor.getTensor(), mref, vref}, - EmitCInterface::On); + createFuncCall(rewriter, loc, name, {}, {adaptor.getTensor(), mref, vref}, + EmitCInterface::On); + rewriter.replaceOp(op, adaptor.getTensor()); return success(); } }; @@ -1149,9 +1149,10 @@ rewriter.create(loc, adaptor.getIndices()[i], mref, constantIndex(rewriter, loc, i)); SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)}; - replaceOpWithFuncCall(rewriter, op, name, {}, - {tensor, mref, values, filled, added, count}, - EmitCInterface::On); + createFuncCall(rewriter, loc, name, {}, + {tensor, mref, values, filled, added, count}, + EmitCInterface::On); + rewriter.replaceOp(op, adaptor.getTensor()); // Deallocate the buffers on exit of the 
loop nest. Operation *parent = op; for (; isa(parent->getParentOp()) || diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir --- a/mlir/test/Dialect/SparseTensor/codegen.mlir +++ b/mlir/test/Dialect/SparseTensor/codegen.mlir @@ -1,5 +1,7 @@ // RUN: mlir-opt %s --sparse-tensor-codegen --canonicalize --cse | FileCheck %s +#SV = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }> + #SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], indexBitWidth = 64, @@ -383,10 +385,10 @@ %filled: memref, %added: memref, %count: index, - %i: index) { - sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i] + %i: index) -> tensor<8x8xf64, #CSR> { + %0 = sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i] : memref, memref, memref, tensor<8x8xf64, #CSR> - return + return %0 : tensor<8x8xf64, #CSR> } // CHECK-LABEL: func @sparse_compression_unordered( @@ -420,8 +422,30 @@ %filled: memref, %added: memref, %count: index, - %i: index) { - sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i] + %i: index) -> tensor<8x8xf64, #UCSR> { + %0 = sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i] : memref, memref, memref, tensor<8x8xf64, #UCSR> - return + return %0 : tensor<8x8xf64, #UCSR> +} + +// CHECK-LABEL: func @sparse_insert( +// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: index, +// CHECK-SAME: %[[A6:.*6]]: f64) +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[T0:.*]] = sparse_tensor.push_back %[[A1]], %[[A2]], %[[C0]] +// CHECK: %[[T1:.*]] = sparse_tensor.push_back %[[A1]], %[[A3]], %[[A5]] +// CHECK: %[[T2:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]] +// CHECK: 
%[[T3:.*]] = memref.load %[[A1]][%[[C2]]] : memref<3xindex> +// CHECK: %[[T4:.*]] = sparse_tensor.push_back %[[A1]], %[[T0]], %[[T3]] +// CHECK: return %[[A0]], %[[A1]], %[[T4]], %[[T1]], %[[T2]] : memref<1xindex>, memref<3xindex>, memref, memref, memref +func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SV> { + %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SV> + %1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SV> + return %1 : tensor<128xf64, #SV> } diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir --- a/mlir/test/Dialect/SparseTensor/conversion.mlir +++ b/mlir/test/Dialect/SparseTensor/conversion.mlir @@ -288,7 +288,7 @@ // CHECK-LABEL: func @sparse_insert( // CHECK-SAME: %[[A:.*]]: !llvm.ptr, // CHECK-SAME: %[[B:.*]]: index, -// CHECK-SAME: %[[C:.*]]: f32) { +// CHECK-SAME: %[[C:.*]]: f32) -> !llvm.ptr { // CHECK-DAG: %[[M:.*]] = memref.alloca() : memref<1xindex> // CHECK-DAG: %[[V:.*]] = memref.alloca() : memref // CHECK-DAG: %[[MC:.*]] = memref.cast %[[M]] : memref<1xindex> to memref @@ -296,12 +296,12 @@ // CHECK-DAG: memref.store %[[B]], %[[M]][%[[C0]]] : memref<1xindex> // CHECK-DAG: memref.store %[[C]], %[[V]][] : memref // CHECK: call @lexInsertF32(%[[A]], %[[MC]], %[[V]]) : (!llvm.ptr, memref, memref) -> () -// CHECK: return +// CHECK: return %[[A]] : !llvm.ptr func.func @sparse_insert(%arg0: tensor<128xf32, #SparseVector>, %arg1: index, - %arg2: f32) { - sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf32, #SparseVector> - return + %arg2: f32) -> tensor<128xf32, #SparseVector> { + %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf32, #SparseVector> + return %0 : tensor<128xf32, #SparseVector> } // CHECK-LABEL: func @sparse_expansion1() @@ -359,7 +359,7 @@ // CHECK-SAME: %[[C:.*2]]: memref, // CHECK-SAME: %[[D:.*3]]: memref, // CHECK-SAME: %[[E:.*4]]: index, -// CHECK-SAME: %[[F:.*5]]: index) 
+// CHECK-SAME: %[[F:.*5]]: index) -> !llvm.ptr { // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[X:.*]] = memref.alloca() : memref<2xindex> // CHECK-DAG: %[[Y:.*]] = memref.cast %[[X]] : memref<2xindex> to memref @@ -368,16 +368,16 @@ // CHECK-DAG: memref.dealloc %[[B]] : memref // CHECK-DAG: memref.dealloc %[[C]] : memref // CHECK-DAG: memref.dealloc %[[D]] : memref -// CHECK: return +// CHECK: return %[[A]] : !llvm.ptr func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>, %values: memref, %filled: memref, %added: memref, %count: index, - %i: index) { - sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i] + %i: index) -> tensor<8x8xf64, #CSR> { + %0 = sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i] : memref, memref, memref, tensor<8x8xf64, #CSR> - return + return %0 : tensor<8x8xf64, #CSR> } // CHECK-LABEL: func @sparse_out1( diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir --- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir +++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir @@ -122,12 +122,12 @@ // CHECK-LABEL: func @sparse_insert( // CHECK-SAME: %[[A:.*]]: tensor<128xf64, #sparse_tensor.encoding<{{.*}}>>, // CHECK-SAME: %[[B:.*]]: index, -// CHECK-SAME: %[[C:.*]]: f64) { -// CHECK: sparse_tensor.insert %[[C]] into %[[A]][%[[B]]] : tensor<128xf64, #{{.*}}> -// CHECK: return -func.func @sparse_insert(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %arg2: f64) { - sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector> - return +// CHECK-SAME: %[[C:.*]]: f64) +// CHECK: %[[T:.*]] = sparse_tensor.insert %[[C]] into %[[A]][%[[B]]] : tensor<128xf64, #{{.*}}> +// CHECK: return %[[T]] : tensor<128xf64, #{{.*}}> +func.func @sparse_insert(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SparseVector> { + %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, 
#SparseVector> + return %0 : tensor<128xf64, #SparseVector> } // ----- @@ -181,17 +181,17 @@ // CHECK-SAME: %[[A3:.*3]]: index // CHECK-SAME: %[[A4:.*4]]: tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>, // CHECK-SAME: %[[A5:.*5]]: index) -// CHECK: sparse_tensor.compress %[[A0]], %[[A1]], %[[A2]], %[[A3]] into %[[A4]][%[[A5]] -// CHECK: return +// CHECK: %[[T:.*]] = sparse_tensor.compress %[[A0]], %[[A1]], %[[A2]], %[[A3]] into %[[A4]][%[[A5]] +// CHECK: return %[[T]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> func.func @sparse_compression(%values: memref, %filled: memref, %added: memref, %count: index, %tensor: tensor<8x8xf64, #SparseMatrix>, - %index: index) { - sparse_tensor.compress %values, %filled, %added, %count into %tensor[%index] + %index: index) -> tensor<8x8xf64, #SparseMatrix> { + %0 = sparse_tensor.compress %values, %filled, %added, %count into %tensor[%index] : memref, memref, memref, tensor<8x8xf64, #SparseMatrix> - return + return %0 : tensor<8x8xf64, #SparseMatrix> } // -----