diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -265,11 +265,14 @@
 }
 
 /// Creates a straightforward counting for-loop.
-static scf::ForOp createFor(OpBuilder &builder, Location loc, Value count) {
+static scf::ForOp createFor(OpBuilder &builder, Location loc, Value count,
+                            SmallVectorImpl<Value> &fields) {
   Type indexType = builder.getIndexType();
   Value zero = constantZero(builder, loc, indexType);
   Value one = constantOne(builder, loc, indexType);
-  scf::ForOp forOp = builder.create<scf::ForOp>(loc, zero, count, one);
+  scf::ForOp forOp = builder.create<scf::ForOp>(loc, zero, count, one, fields);
+  for (unsigned i = 0, e = fields.size(); i < e; i++)
+    fields[i] = forOp.getRegionIterArg(i);
   builder.setInsertionPointToStart(forOp.getBody());
   return forOp;
 }
@@ -280,6 +283,9 @@
                             SmallVectorImpl<Value> &fields, unsigned field,
                             Value value) {
   assert(field < fields.size());
+  Type etp = fields[field].getType().cast<ShapedType>().getElementType();
+  if (value.getType() != etp)
+    value = builder.create<arith::IndexCastOp>(loc, etp, value);
   fields[field] = builder.create<PushBackOp>(loc, fields[field].getType(),
                                              fields[1], fields[field], value,
                                              APInt(64, field));
@@ -298,11 +304,8 @@
   if (rank != 1 || !isCompressedDim(rtp, 0) || !isUniqueDim(rtp, 0) ||
       !isOrderedDim(rtp, 0))
     return; // TODO: add codegen
-  // push_back memSizes pointers-0 0
   // push_back memSizes indices-0 index
   // push_back memSizes values value
-  Value zero = constantIndex(builder, loc, 0);
-  createPushback(builder, loc, fields, 2, zero);
   createPushback(builder, loc, fields, 3, indices[0]);
   createPushback(builder, loc, fields, 4, value);
 }
@@ -316,9 +319,12 @@
   if (rtp.getShape().size() != 1 || !isCompressedDim(rtp, 0) ||
       !isUniqueDim(rtp, 0) || !isOrderedDim(rtp, 0))
     return; // TODO: add codegen
+  // push_back memSizes pointers-0 0
   // push_back memSizes pointers-0 memSizes[2]
+  Value zero = constantIndex(builder, loc, 0);
   Value two = constantIndex(builder, loc, 2);
   Value size = builder.create<memref::LoadOp>(loc, fields[1], two);
+  createPushback(builder, loc, fields, 2, zero);
   createPushback(builder, loc, fields, 2, size);
 }
 
@@ -460,6 +466,7 @@
     Location loc = op.getLoc();
     SmallVector<Value> fields;
     createAllocFields(rewriter, loc, resType, adaptor.getOperands(), fields);
+    // Replace operation with resulting memrefs.
     rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
     return success();
   }
@@ -504,6 +511,7 @@
     // Generate optional insertion finalization code.
     if (op.getHasInserts())
      genEndInsert(rewriter, op.getLoc(), srcType, fields);
+    // Replace operation with resulting memrefs.
     rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), srcType, fields));
     return success();
   }
@@ -591,23 +599,26 @@
     //   sparsity of the expanded access pattern.
     //
     // Generate
-    //    for (i = 0; i < count; i++) {
+    //    out_memrefs = for (i = 0; i < count; i++)(in_memrefs) {
     //      index = added[i];
     //      value = values[index];
     //      insert({prev_indices, index}, value);
+    //      new_memrefs = insert(in_memrefs, {prev_indices, index}, value);
     //      values[index] = 0;
     //      filled[index] = false;
+    //      yield new_memrefs
     //    }
-    Value i = createFor(rewriter, loc, count).getInductionVar();
+    scf::ForOp loop = createFor(rewriter, loc, count, fields);
+    Value i = loop.getInductionVar();
     Value index = rewriter.create<memref::LoadOp>(loc, added, i);
     Value value = rewriter.create<memref::LoadOp>(loc, values, index);
     indices.push_back(index);
-    // TODO: generate yield cycle
     genInsert(rewriter, loc, dstType, fields, indices, value);
     rewriter.create<memref::StoreOp>(loc, constantZero(rewriter, loc, eltType),
                                      values, index);
     rewriter.create<memref::StoreOp>(loc, constantI1(rewriter, loc, false),
                                      filled, index);
+    rewriter.create<scf::YieldOp>(loc, fields);
     // Deallocate the buffers on exit of the full loop nest.
     Operation *parent = op;
     for (; isa<scf::ForOp>(parent->getParentOp()) ||
@@ -620,7 +631,9 @@
     rewriter.create<memref::DeallocOp>(loc, values);
     rewriter.create<memref::DeallocOp>(loc, filled);
     rewriter.create<memref::DeallocOp>(loc, added);
-    rewriter.replaceOp(op, genTuple(rewriter, loc, dstType, fields));
+    // Replace operation with resulting memrefs.
+    rewriter.replaceOp(op,
+                       genTuple(rewriter, loc, dstType, loop->getResults()));
     return success();
   }
 };
@@ -641,6 +654,7 @@
     // Generate insertion.
     Value value = adaptor.getValue();
     genInsert(rewriter, op->getLoc(), dstType, fields, indices, value);
+    // Replace operation with resulting memrefs.
     rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), dstType, fields));
     return success();
   }
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -354,6 +354,49 @@
   return %added : memref<?xindex>
 }
 
+// CHECK-LABEL: func @sparse_compression_1d(
+// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[A5:.*5]]: memref<?xf64>,
+// CHECK-SAME: %[[A6:.*6]]: memref<?xi1>,
+// CHECK-SAME: %[[A7:.*7]]: memref<?xindex>,
+// CHECK-SAME: %[[A8:.*8]]: index)
+// CHECK-DAG: %[[B0:.*]] = arith.constant false
+// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: sparse_tensor.sort %[[A8]], %[[A7]] : memref<?xindex>
+// CHECK: %[[R:.*]]:2 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] iter_args(%[[P0:.*]] = %[[A3]], %[[P1:.*]] = %[[A4]]) -> (memref<?xindex>, memref<?xf64>) {
+// CHECK: %[[T1:.*]] = memref.load %[[A7]][%[[I]]] : memref<?xindex>
+// CHECK: %[[T2:.*]] = memref.load %[[A5]][%[[T1]]] : memref<?xf64>
+// CHECK: %[[T3:.*]] = sparse_tensor.push_back %[[A1]], %[[P0]], %[[T1]] {idx = 3 : index} : memref<3xindex>, memref<?xindex>, index
+// CHECK: %[[T4:.*]] = sparse_tensor.push_back %[[A1]], %[[P1]], %[[T2]] {idx = 4 : index} : memref<3xindex>, memref<?xf64>, f64
+// CHECK: memref.store %[[F0]], %arg5[%[[T1]]] : memref<?xf64>
+// CHECK: memref.store %[[B0]], %arg6[%[[T1]]] : memref<?xi1>
+// CHECK: scf.yield %[[T3]], %[[T4]] : memref<?xindex>, memref<?xf64>
+// CHECK: }
+// CHECK: memref.dealloc %[[A5]] : memref<?xf64>
+// CHECK: memref.dealloc %[[A6]] : memref<?xi1>
+// CHECK: memref.dealloc %[[A7]] : memref<?xindex>
+// CHECK: %[[LL:.*]] = memref.load %[[A1]][%[[C2]]] : memref<3xindex>
+// CHECK: %[[P1:.*]] = sparse_tensor.push_back %[[A1]], %[[A2]], %[[C0]] {idx = 2 : index} : memref<3xindex>, memref<?xindex>, index
+// CHECK: %[[P2:.*]] = sparse_tensor.push_back %[[A1]], %[[P1]], %[[LL]] {idx = 2 : index} : memref<3xindex>, memref<?xindex>, index
+// CHECK: return %[[A0]], %[[A1]], %[[P2]], %[[R]]#0, %[[R]]#1 : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>,
+                                 %values: memref<?xf64>,
+                                 %filled: memref<?xi1>,
+                                 %added: memref<?xindex>,
+                                 %count: index) -> tensor<100xf64, #SV> {
+  %0 = sparse_tensor.compress %values, %filled, %added, %count into %tensor[]
+    : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<100xf64, #SV>
+  %1 = sparse_tensor.load %0 hasInserts : tensor<100xf64, #SV>
+  return %1 : tensor<100xf64, #SV>
+}
+
 // CHECK-LABEL: func @sparse_compression(
 // CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
 // CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
@@ -372,7 +415,7 @@
 // CHECK: sparse_tensor.sort %[[A8]], %[[A7]] : memref<?xindex>
 // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] {
 // CHECK-NEXT: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref<?xindex>
-// TODO: insert
+// TODO: 2D-insert
 // CHECK-DAG: memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref<?xf64>
 // CHECK-DAG: memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref<?xi1>
 // CHECK-NEXT: }
@@ -388,7 +431,8 @@
                               %i: index) -> tensor<8x8xf64, #CSR> {
   %0 = sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
     : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #CSR>
-  return %0 : tensor<8x8xf64, #CSR>
+  %1 = sparse_tensor.load %0 hasInserts : tensor<8x8xf64, #CSR>
+  return %1 : tensor<8x8xf64, #CSR>
 }
 
 // CHECK-LABEL: func @sparse_compression_unordered(
@@ -409,7 +453,7 @@
 // CHECK-NOT: sparse_tensor.sort
 // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] {
 // CHECK-NEXT: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref<?xindex>
-// TODO: insert
+// TODO: 2D-insert
 // CHECK-DAG: memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref<?xf64>
 // CHECK-DAG: memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref<?xi1>
 // CHECK-NEXT: }
@@ -425,7 +469,8 @@
                               %i: index) -> tensor<8x8xf64, #UCSR> {
   %0 = sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
     : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #UCSR>
-  return %0 : tensor<8x8xf64, #UCSR>
+  %1 = sparse_tensor.load %0 hasInserts : tensor<8x8xf64, #UCSR>
+  return %1 : tensor<8x8xf64, #UCSR>
 }
 
 // CHECK-LABEL: func @sparse_insert(
@@ -438,10 +483,10 @@
 // CHECK-SAME: %[[A6:.*6]]: f64)
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[T0:.*]] = sparse_tensor.push_back %[[A1]], %[[A2]], %[[C0]]
 // CHECK: %[[T1:.*]] = sparse_tensor.push_back %[[A1]], %[[A3]], %[[A5]]
 // CHECK: %[[T2:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]]
 // CHECK: %[[T3:.*]] = memref.load %[[A1]][%[[C2]]] : memref<3xindex>
+// CHECK: %[[T0:.*]] = sparse_tensor.push_back %[[A1]], %[[A2]], %[[C0]]
 // CHECK: %[[T4:.*]] = sparse_tensor.push_back %[[A1]], %[[T0]], %[[T3]]
 // CHECK: return %[[A0]], %[[A1]], %[[T4]], %[[T1]], %[[T2]] : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
 func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SV> {
@@ -449,3 +494,27 @@
   %1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SV>
   return %1 : tensor<128xf64, #SV>
 }
+
+// CHECK-LABEL: func @sparse_insert_typed(
+// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[A5:.*5]]: index,
+// CHECK-SAME: %[[A6:.*6]]: f64)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[S1:.*]] = arith.index_cast %[[A5]] : index to i64
+// CHECK: %[[T1:.*]] = sparse_tensor.push_back %[[A1]], %[[A3]], %[[S1]]
+// CHECK: %[[T2:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]]
+// CHECK: %[[T3:.*]] = memref.load %[[A1]][%[[C2]]] : memref<3xindex>
+// CHECK: %[[T0:.*]] = sparse_tensor.push_back %[[A1]], %[[A2]], %[[C0]]
+// CHECK: %[[S2:.*]] = arith.index_cast %[[T3]] : index to i32
+// CHECK: %[[T4:.*]] = sparse_tensor.push_back %[[A1]], %[[T0]], %[[S2]]
+// CHECK: return %[[A0]], %[[A1]], %[[T4]], %[[T1]], %[[T2]] : memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
+func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SparseVector> {
+  %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector>
+  %1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SparseVector>
+  return %1 : tensor<128xf64, #SparseVector>
+}
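
Reviewer note (not part of the patch): the core change is that the compress loop now carries the growing indices/values buffers through the scf.for as iter_args and yields the buffers returned by sparse_tensor.push_back, since each push_back produces a new (possibly re-allocated) SSA memref that must flow out of the loop. A minimal hand-written sketch of that yield cycle is shown below; the function name, buffer names, and the f64 element type are illustrative assumptions, not output of this patch.

  // Illustrative only: loop-carried memrefs through scf.for, mirroring the
  // yield cycle that the updated createFor/SparseCompressConverter emit.
  func.func @yield_cycle(%count: index, %sizes: memref<3xindex>,
                         %added: memref<?xindex>, %values: memref<?xf64>,
                         %idx0: memref<?xindex>, %val0: memref<?xf64>)
      -> (memref<?xindex>, memref<?xf64>) {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    // The current buffers enter as iter_args; the buffers produced by
    // push_back are yielded back, so each iteration sees the latest ones.
    %r:2 = scf.for %i = %c0 to %count step %c1
        iter_args(%p0 = %idx0, %p1 = %val0)
        -> (memref<?xindex>, memref<?xf64>) {
      %index = memref.load %added[%i] : memref<?xindex>
      %value = memref.load %values[%index] : memref<?xf64>
      %n0 = sparse_tensor.push_back %sizes, %p0, %index {idx = 3 : index}
          : memref<3xindex>, memref<?xindex>, index
      %n1 = sparse_tensor.push_back %sizes, %p1, %value {idx = 4 : index}
          : memref<3xindex>, memref<?xf64>, f64
      scf.yield %n0, %n1 : memref<?xindex>, memref<?xf64>
    }
    return %r#0, %r#1 : memref<?xindex>, memref<?xf64>
  }

The loop results (%r#0, %r#1 above) correspond to what genTuple now packs into the replacement value, which is why SparseCompressConverter switched from passing `fields` to `loop->getResults()`.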