diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1049,50 +1049,52 @@
 /// Generates a store on a dense or sparse tensor.
 static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                            Value rhs) {
-  linalg::GenericOp op = env.op();
-  Location loc = op.getLoc();
+  // Only unary and binary are allowed to return uninitialized rhs
+  // to indicate missing output.
+  if (!rhs) {
+    assert(env.exp(exp).kind == TensorExp::Kind::kUnary ||
+           env.exp(exp).kind == TensorExp::Kind::kBinary);
+    return;
+  }
   // Test if this is a scalarized reduction.
   if (env.isReduc()) {
     env.updateReduc(rhs);
     return;
   }
-  // Store during insertion.
+  // Regular store.
+  linalg::GenericOp op = env.op();
+  Location loc = op.getLoc();
   OpOperand *t = op.getDpsInitOperand(0);
-  if (env.isSparseOutput(t)) {
-    if (!rhs) {
-      // Only unary and binary are allowed to return uninitialized rhs
-      // to indicate missing output.
-      assert(env.exp(exp).kind == TensorExp::Kind::kUnary ||
-             env.exp(exp).kind == TensorExp::Kind::kBinary);
-    } else if (env.exp(exp).kind == TensorExp::Kind::kSelect) {
-      // Select operation insertion.
-      Value chain = env.getInsertionChain();
-      scf::IfOp ifOp =
-          builder.create<scf::IfOp>(loc, chain.getType(), rhs, /*else=*/true);
-      builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-      // Existing value was preserved to be used here.
-      assert(env.exp(exp).val);
-      Value v0 = env.exp(exp).val;
-      genInsertionStore(env, builder, t, v0);
-      env.merger().clearExprValue(exp);
-      // Yield modified insertion chain along true branch.
-      Value mchain = env.getInsertionChain();
-      builder.create<scf::YieldOp>(op.getLoc(), mchain);
-      // Yield original insertion chain along false branch.
-      builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-      builder.create<scf::YieldOp>(loc, chain);
-      // Done with if statement.
-      env.updateInsertionChain(ifOp->getResult(0));
-      builder.setInsertionPointAfter(ifOp);
-    } else {
-      genInsertionStore(env, builder, t, rhs);
-    }
+  if (!env.isSparseOutput(t)) {
+    SmallVector<Value> args;
+    Value ptr = genSubscript(env, builder, t, args);
+    builder.create<memref::StoreOp>(loc, rhs, ptr, args);
     return;
   }
-  // Actual store.
-  SmallVector<Value> args;
-  Value ptr = genSubscript(env, builder, t, args);
-  builder.create<memref::StoreOp>(loc, rhs, ptr, args);
+  // Store during sparse insertion.
+  if (env.exp(exp).kind != TensorExp::Kind::kSelect) {
+    genInsertionStore(env, builder, t, rhs);
+    return;
+  }
+  // Select operation insertion.
+  Value chain = env.getInsertionChain();
+  scf::IfOp ifOp =
+      builder.create<scf::IfOp>(loc, chain.getType(), rhs, /*else=*/true);
+  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+  // Existing value was preserved to be used here.
+  assert(env.exp(exp).val);
+  Value v0 = env.exp(exp).val;
+  genInsertionStore(env, builder, t, v0);
+  env.merger().clearExprValue(exp);
+  // Yield modified insertion chain along true branch.
+  Value mchain = env.getInsertionChain();
+  builder.create<scf::YieldOp>(op.getLoc(), mchain);
+  // Yield original insertion chain along false branch.
+  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+  builder.create<scf::YieldOp>(loc, chain);
+  // Done with if statement.
+  env.updateInsertionChain(ifOp->getResult(0));
+  builder.setInsertionPointAfter(ifOp);
 }
 
 /// Generates an invariant value.
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_unary.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_unary.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_unary.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_unary.mlir
@@ -32,14 +32,14 @@
 //
 // Traits for tensor operations.
 //
-#trait_vec_scale = {
+#trait_vec = {
   indexing_maps = [
     affine_map<(i) -> (i)>,  // a (in)
     affine_map<(i) -> (i)>   // x (out)
   ],
   iterator_types = ["parallel"]
 }
-#trait_mat_scale = {
+#trait_mat = {
   indexing_maps = [
     affine_map<(i,j) -> (i,j)>,  // A (in)
     affine_map<(i,j) -> (i,j)>   // X (out)
@@ -49,13 +49,13 @@
 
 module {
   // Invert the structure of a sparse vector. Present values become missing.
-  // Missing values are filled with 1 (i32).
-  func.func @vector_complement(%arga: tensor<?xf64, #SparseVector>) -> tensor<?xi32, #SparseVector> {
+  // Missing values are filled with 1 (i32). Output is sparse.
+  func.func @vector_complement_sparse(%arga: tensor<?xf64, #SparseVector>) -> tensor<?xi32, #SparseVector> {
     %c = arith.constant 0 : index
     %ci1 = arith.constant 1 : i32
     %d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
     %xv = bufferization.alloc_tensor(%d) : tensor<?xi32, #SparseVector>
-    %0 = linalg.generic #trait_vec_scale
+    %0 = linalg.generic #trait_vec
        ins(%arga: tensor<?xf64, #SparseVector>)
         outs(%xv: tensor<?xi32, #SparseVector>) {
         ^bb(%a: f64, %x: i32):
@@ -69,13 +69,35 @@
     return %0 : tensor<?xi32, #SparseVector>
   }
 
+  // Invert the structure of a sparse vector, where missing values are
+  // filled with 1. For a dense output, the sparse compiler initializes
+  // the buffer to all zero at all other places.
+  func.func @vector_complement_dense(%arga: tensor<?xf64, #SparseVector>) -> tensor<?xi32> {
+    %c = arith.constant 0 : index
+    %d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
+    %xv = bufferization.alloc_tensor(%d) : tensor<?xi32>
+    %0 = linalg.generic #trait_vec
+       ins(%arga: tensor<?xf64, #SparseVector>)
+        outs(%xv: tensor<?xi32>) {
+        ^bb(%a: f64, %x: i32):
+          %1 = sparse_tensor.unary %a : f64 to i32
+            present={}
+            absent={
+              %ci1 = arith.constant 1 : i32
+              sparse_tensor.yield %ci1 : i32
+            }
+          linalg.yield %1 : i32
+    } -> tensor<?xi32>
+    return %0 : tensor<?xi32>
+  }
+
   // Negate existing values. Fill missing ones with +1.
   func.func @vector_negation(%arga: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
     %c = arith.constant 0 : index
     %cf1 = arith.constant 1.0 : f64
     %d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
     %xv = bufferization.alloc_tensor(%d) : tensor<?xf64, #SparseVector>
-    %0 = linalg.generic #trait_vec_scale
+    %0 = linalg.generic #trait_vec
        ins(%arga: tensor<?xf64, #SparseVector>)
         outs(%xv: tensor<?xf64, #SparseVector>) {
         ^bb(%a: f64, %x: f64):
@@ -98,7 +120,7 @@
     %c = arith.constant 0 : index
     %d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
     %xv = bufferization.alloc_tensor(%d) : tensor<?xf64, #SparseVector>
-    %0 = linalg.generic #trait_vec_scale
+    %0 = linalg.generic #trait_vec
        ins(%arga: tensor<?xf64, #SparseVector>)
         outs(%xv: tensor<?xf64, #SparseVector>) {
         ^bb(%a: f64, %x: f64):
@@ -126,7 +148,7 @@
     %d0 = tensor.dim %argx, %c0 : tensor<?x?xf64, #DCSR>
     %d1 = tensor.dim %argx, %c1 : tensor<?x?xf64, #DCSR>
     %xv = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf64, #DCSR>
-    %0 = linalg.generic #trait_mat_scale
+    %0 = linalg.generic #trait_mat
        ins(%argx: tensor<?x?xf64, #DCSR>)
         outs(%xv: tensor<?x?xf64, #DCSR>) {
         ^bb(%a: f64, %x: f64):
@@ -153,7 +175,7 @@
     %d0 = tensor.dim %argx, %c0 : tensor<?x?xf64, #DCSR>
     %d1 = tensor.dim %argx, %c1 : tensor<?x?xf64, #DCSR>
     %xv = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf64, #DCSR>
-    %0 = linalg.generic #trait_mat_scale
+    %0 = linalg.generic #trait_mat
        ins(%argx: tensor<?x?xf64, #DCSR>)
         outs(%xv: tensor<?x?xf64, #DCSR>) {
         ^bb(%a: f64, %x: f64):
@@ -223,6 +245,7 @@
 
   // Driver method to call and verify vector kernels.
   func.func @entry() {
+    %cmu = arith.constant -99 : i32
     %c0 = arith.constant 0 : index
 
     // Setup sparse vectors.
@@ -240,7 +263,7 @@
     %sm1 = sparse_tensor.convert %m1 : tensor<4x8xf64> to tensor<?x?xf64, #DCSR>
 
     // Call sparse vector kernels.
-    %0 = call @vector_complement(%sv1)
+    %0 = call @vector_complement_sparse(%sv1)
        : (tensor<?xf64, #SparseVector>) -> tensor<?xi32, #SparseVector>
     %1 = call @vector_negation(%sv1)
        : (tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector>
@@ -253,6 +276,9 @@
     %4 = call @matrix_slice(%sm1)
        : (tensor<?x?xf64, #DCSR>) -> tensor<?x?xf64, #DCSR>
 
+    // Call kernel with dense output.
+    %5 = call @vector_complement_dense(%sv1) : (tensor<?xf64, #SparseVector>) -> tensor<?xi32>
+
     //
     // Verify the results.
     //
@@ -268,6 +294,7 @@
     // CHECK-NEXT: ( ( 3, 3, 0, 0, 0, 0, 0, 0 ), ( 0, 0, 0, 0, 0, 0, 0, 3 ), ( 0, 0, 4, 0, 5, 0, 0, 6 ), ( 7, 0, 7, 7, 0, 0, 0, 0 ) )
     // CHECK-NEXT: ( 99, 99, 99, 99, 5, 6, 99, 99, 99, 0, 0, 0, 0, 0, 0, 0 )
     // CHECK-NEXT: ( ( 99, 99, 0, 0, 0, 0, 0, 0 ), ( 0, 0, 0, 0, 0, 0, 0, 99 ), ( 0, 0, 99, 0, 5, 0, 0, 6 ), ( 99, 0, 99, 99, 0, 0, 0, 0 ) )
+    // CHECK-NEXT: ( 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0 )
     //
     call @dump_vec_f64(%sv1) : (tensor<?xf64, #SparseVector>) -> ()
     call @dump_vec_i32(%0) : (tensor<?xi32, #SparseVector>) -> ()
@@ -275,6 +302,8 @@
     call @dump_vec_f64(%2) : (tensor<?xf64, #SparseVector>) -> ()
     call @dump_mat(%3) : (tensor<?x?xf64, #DCSR>) -> ()
     call @dump_mat(%4) : (tensor<?x?xf64, #DCSR>) -> ()
+    %v = vector.transfer_read %5[%c0], %cmu: tensor<?xi32>, vector<32xi32>
+    vector.print %v : vector<32xi32>
 
     // Release the resources.
     bufferization.dealloc_tensor %sv1 : tensor<?xf64, #SparseVector>
@@ -284,6 +313,7 @@
     bufferization.dealloc_tensor %2 : tensor<?xf64, #SparseVector>
     bufferization.dealloc_tensor %3 : tensor<?x?xf64, #DCSR>
     bufferization.dealloc_tensor %4 : tensor<?x?xf64, #DCSR>
+    bufferization.dealloc_tensor %5 : tensor<?xi32>
     return
   }
 }