diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -35,6 +35,11 @@
 // Helper methods.
 //===----------------------------------------------------------------------===//
 
+/// Returns the tuple value of the adapted tensor.
+static UnrealizedConversionCastOp getTuple(Value tensor) {
+  return llvm::cast<UnrealizedConversionCastOp>(tensor.getDefiningOp());
+}
+
 /// Flatten a list of operands that may contain sparse tensors.
 static void flattenOperands(ValueRange operands,
                             SmallVectorImpl<Value> &flattened) {
@@ -43,14 +48,13 @@
   // ==>
   // memref ..., c, memref ...
   for (auto operand : operands) {
-    if (auto cast =
-            dyn_cast<UnrealizedConversionCastOp>(operand.getDefiningOp());
-        cast && getSparseTensorEncoding(cast->getResultTypes()[0]))
+    if (auto tuple = getTuple(operand);
+        tuple && getSparseTensorEncoding(tuple->getResultTypes()[0]))
       // An unrealized_conversion_cast will be inserted by type converter to
       // inter-mix the gap between 1:N conversion between sparse tensors and
      // fields. In this case, take the operands in the cast and replace the
      // sparse tensor output with the flattened type array.
-      flattened.append(cast.getOperands().begin(), cast.getOperands().end());
+      flattened.append(tuple.getOperands().begin(), tuple.getOperands().end());
     else
       flattened.push_back(operand);
   }
@@ -73,8 +77,7 @@
 
   // Any other query can consult the dimSizes array at field 0 using,
   // accounting for the reordering applied to the sparse storage.
-  auto tuple =
-      llvm::cast<UnrealizedConversionCastOp>(adaptedValue.getDefiningOp());
+  auto tuple = getTuple(adaptedValue);
   Value idx = constantIndex(rewriter, loc, toStoredDim(tensorTp, dim));
   return rewriter.create<memref::LoadOp>(loc, tuple.getInputs().front(), idx)
       .getResult();
@@ -264,6 +267,54 @@
   return forOp;
 }
 
+/// Creates a pushback op.
+static Value createPushback(OpBuilder &builder, Location loc, ValueRange fields,
+                            unsigned field, Value value) {
+  return builder.create<PushBackOp>(loc, fields[field].getType(), fields[1],
+                                    fields[field], value, APInt(64, field));
+}
+
+/// Generates insertion code.
+//
+// TODO: generalize this for any rank and format; currently it is just sparse
+// vectors as a proof of concept that we have everything in place!
+//
+static void genInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
+                      UnrealizedConversionCastOp tuple,
+                      SmallVectorImpl<Value> &indices, Value value) {
+  unsigned rank = indices.size();
+  assert(rtp.getShape().size() == rank);
+  if (rtp.getShape().size() != 1 || !isCompressedDim(rtp, 0) ||
+      !isUniqueDim(rtp, 0) || !isOrderedDim(rtp, 0))
+    return; // TODO: add codegen
+  // push_back memSizes pointers-0 0
+  // push_back memSizes indices-0 index
+  // push_back memSizes values value
+  auto fields = tuple.getInputs();
+  Value zero = constantIndex(builder, loc, 0);
+  createPushback(builder, loc, fields, 2, zero);
+  createPushback(builder, loc, fields, 3, indices[0]);
+  createPushback(builder, loc, fields, 4, value);
+  // TODO: make insert return SSA val and update values!
+}
+
+/// Generates insertion finalization code.
+//
+// TODO: this too only works for the very simple case
+//
+static void genEndInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
+                         UnrealizedConversionCastOp tuple) {
+  if (rtp.getShape().size() != 1 || !isCompressedDim(rtp, 0) ||
+      !isUniqueDim(rtp, 0) || !isOrderedDim(rtp, 0))
+    return; // TODO: add codegen
+  // push_back memSizes pointers-0 memSizes[2]
+  auto fields = tuple.getInputs();
+  Value two = constantIndex(builder, loc, 2);
+  Value size = builder.create<memref::LoadOp>(loc, fields[1], two);
+  createPushback(builder, loc, fields, 2, size);
+  // TODO: make compress return SSA val and update values!
+}
+
 //===----------------------------------------------------------------------===//
 // Codegen rules.
 //===----------------------------------------------------------------------===//
@@ -424,8 +475,7 @@
 
     // Replace the sparse tensor deallocation with field deallocations.
     Location loc = op.getLoc();
-    auto tuple = llvm::cast<UnrealizedConversionCastOp>(
-        adaptor.getTensor().getDefiningOp());
+    auto tuple = getTuple(adaptor.getTensor());
     for (auto input : tuple.getInputs())
       // Deallocate every buffer used to store the sparse tensor handler.
       rewriter.create<memref::DeallocOp>(loc, input);
@@ -443,8 +493,10 @@
   matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     if (op.getHasInserts()) {
-      // Finalize any pending insertions.
-      // TODO: implement
+      RankedTensorType srcType =
+          op.getTensor().getType().cast<RankedTensorType>();
+      auto tuple = getTuple(adaptor.getTensor());
+      genEndInsert(rewriter, op.getLoc(), srcType, tuple);
     }
     rewriter.replaceOp(op, adaptor.getOperands());
     return success();
@@ -514,10 +566,15 @@
     RankedTensorType dstType =
         op.getTensor().getType().cast<RankedTensorType>();
     Type eltType = dstType.getElementType();
+    auto tuple = getTuple(adaptor.getTensor());
     Value values = adaptor.getValues();
     Value filled = adaptor.getFilled();
     Value added = adaptor.getAdded();
     Value count = adaptor.getCount();
+    // Prepare indices.
+    SmallVector<Value> indices;
+    for (Value idx : adaptor.getIndices())
+      indices.push_back(idx);
     // If the innermost dimension is ordered, we need to sort the indices
     // in the "added" array prior to applying the compression.
     unsigned rank = dstType.getShape().size();
@@ -532,21 +589,19 @@
     //   for (i = 0; i < count; i++) {
     //     index = added[i];
     //     value = values[index];
-    //
-    //     TODO: insert prev_indices, index, value
-    //
+    //     insert({prev_indices, index}, value);
     //     values[index] = 0;
     //     filled[index] = false;
     //   }
     Value i = createFor(rewriter, loc, count).getInductionVar();
     Value index = rewriter.create<memref::LoadOp>(loc, added, i);
-    rewriter.create<memref::LoadOp>(loc, values, index);
-    // TODO: insert
+    Value value = rewriter.create<memref::LoadOp>(loc, values, index);
+    indices.push_back(index);
+    genInsert(rewriter, loc, dstType, tuple, indices, value);
     rewriter.create<memref::StoreOp>(loc, constantZero(rewriter, loc, eltType),
                                      values, index);
     rewriter.create<memref::StoreOp>(loc, constantI1(rewriter, loc, false),
                                      filled, index);
-
     // Deallocate the buffers on exit of the full loop nest.
     Operation *parent = op;
     for (; isa<scf::ForOp>(parent->getParentOp()) ||
@@ -564,6 +619,28 @@
   }
 };
 
+/// Sparse codegen rule for the insert operator.
+class SparseInsertConverter : public OpConversionPattern<InsertOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(InsertOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    RankedTensorType dstType =
+        op.getTensor().getType().cast<RankedTensorType>();
+    auto tuple = getTuple(adaptor.getTensor());
+    // Prepare indices.
+    SmallVector<Value> indices;
+    for (Value idx : adaptor.getIndices())
+      indices.push_back(idx);
+    // Generate insertion.
+    Value value = adaptor.getValue();
+    genInsert(rewriter, op->getLoc(), dstType, tuple, indices, value);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 /// Base class for getter-like operations, e.g., to_indices, to_pointers.
 template <typename SourceOp, typename Base>
 class SparseGetterOpConverter : public OpConversionPattern<SourceOp> {
@@ -576,8 +653,7 @@
     // Replace the requested pointer access with corresponding field.
     // The cast_op is inserted by type converter to intermix 1:N type
     // conversion.
-    auto tuple = llvm::cast<UnrealizedConversionCastOp>(
-        adaptor.getTensor().getDefiningOp());
+    auto tuple = getTuple(adaptor.getTensor());
     unsigned idx = Base::getIndexForOp(tuple, op);
     auto fields = tuple.getInputs();
     assert(idx < fields.size());
@@ -648,6 +724,7 @@
                SparseCastConverter, SparseTensorAllocConverter,
                SparseTensorDeallocConverter, SparseTensorLoadConverter,
                SparseExpandConverter, SparseCompressConverter,
-               SparseToPointersConverter, SparseToIndicesConverter,
-               SparseToValuesConverter>(typeConverter, patterns.getContext());
+               SparseInsertConverter, SparseToPointersConverter,
+               SparseToIndicesConverter, SparseToValuesConverter>(
+      typeConverter, patterns.getContext());
 }
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -1,5 +1,7 @@
 // RUN: mlir-opt %s --sparse-tensor-codegen --canonicalize --cse | FileCheck %s
 
+#SV = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
+
 #SparseVector = #sparse_tensor.encoding<{
   dimLevelType = [ "compressed" ],
   indexBitWidth = 64,
@@ -425,3 +427,26 @@
     : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #UCSR>
   return
 }
+
+// CHECK-LABEL: func @sparse_insert(
+//  CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+//  CHECK-SAME: %[[A5:.*5]]: index,
+//  CHECK-SAME: %[[A6:.*6]]: f64)
+//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// TODO: fix what is passed in and out as SSA value
+//       CHECK: sparse_tensor.push_back %[[A1]],
+//       CHECK: sparse_tensor.push_back %[[A1]],
+//       CHECK: sparse_tensor.push_back %[[A1]],
+//       CHECK: memref.load %[[A1]][%[[C2]]] : memref<3xindex>
+//       CHECK: sparse_tensor.push_back %[[A1]],
+// TODO: fix what is returned!
+func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SV> {
+  sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SV>
+  %0 = sparse_tensor.load %arg0 hasInserts : tensor<128xf64, #SV>
+  return %0 : tensor<128xf64, #SV>
+}