diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -361,22 +361,21 @@ /// Fills the per-dimension sparsity information for all tensors. static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) { bool annotated = false; - unsigned numTensors = op.getNumShapedOperands(); - unsigned lhs = numTensors - 1; - for (unsigned t = 0; t < numTensors; t++) { - auto map = op.getIndexingMap(t); + OpOperand *lhs = op.getOutputOperand(0); + for (OpOperand *t : op.getInputAndOutputOperands()) { + auto map = op.getTiedIndexingMap(t); if (!map.isProjectedPermutation()) return false; - auto enc = getSparseTensorEncoding(op.getShapedType(t)); + auto enc = getSparseTensorEncoding(t->get().getType()); if (enc) { annotated = true; if (t == lhs) return false; // TODO: handle sparse outputs } - assert(map.getNumResults() == op.getShapedType(t).getRank()); + assert(map.getNumResults() == op.getRank(t)); for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { unsigned idx = map.getDimPosition(perm(enc, d)); - merger.setDim(t, idx, toDim(enc, d)); + merger.setDim(t->getOperandNumber(), idx, toDim(enc, d)); } } return annotated; @@ -414,10 +413,9 @@ std::vector> adjM(n, std::vector(n, false)); // Iterate over the indexing maps of every tensor in the tensor expression. - unsigned numTensors = op.getNumShapedOperands(); - for (unsigned t = 0; t < numTensors; t++) { - auto map = op.getIndexingMap(t); - auto enc = getSparseTensorEncoding(op.getShapedType(t)); + for (OpOperand *t : op.getInputAndOutputOperands()) { + auto map = op.getTiedIndexingMap(t); + auto enc = getSparseTensorEncoding(t->get().getType()); assert(map.getNumDims() == n); // Skip dense tensor constraints when sparse only is requested. if (sparseOnly && !enc) @@ -495,8 +493,8 @@ // set to the undefined index in that dimension. An invariant expression // is set to a synthetic tensor with undefined indices only. unsigned s = merger.addSet(); - unsigned t = - kind == Kind::kTensor ? merger.exp(exp).e0 : op.getNumShapedOperands(); + unsigned t = kind == Kind::kTensor ? merger.exp(exp).e0 + : op.getNumInputsAndOutputs(); merger.set(s).push_back(merger.addLat(t, idx, exp)); return s; } @@ -538,7 +536,7 @@ linalg::GenericOp op, MemRefType denseTp, ArrayRef args) { Location loc = op.getLoc(); - Value tensor = op.getOutput(0); + Value tensor = op.getOutputOperand(0)->get(); // The output tensor simply could materialize from the buffer that will // be generated for the tensor present in the outs() clause. This has // the major advantage that the sparse kernel only updates the nonzero @@ -561,24 +559,21 @@ static void genBuffers(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, linalg::GenericOp op) { Location loc = op.getLoc(); - unsigned numTensors = op.getNumShapedOperands(); - unsigned numInputs = op.getNumInputs(); - assert(numTensors == numInputs + 1); + assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1); // For every tensor, find lower and upper bound on dimensions, set the // same bounds on loop indices, and obtain dense or sparse buffer(s). SmallVector args; - for (unsigned t = 0; t < numTensors; t++) { - Value tensor = t < numInputs ? op.getInput(t) : op.getOutput(0); - auto tensorType = op.getShapedType(t); - auto shape = tensorType.getShape(); - auto map = op.getIndexingMap(t); - auto enc = getSparseTensorEncoding(tensorType); + for (OpOperand *t : op.getInputAndOutputOperands()) { + Type elementType = getElementTypeOrSelf(t->get().getType()); + auto shape = op.getShape(t); + auto map = op.getTiedIndexingMap(t); + auto enc = getSparseTensorEncoding(t->get().getType()); // Scan all dimensions of current tensor. args.clear(); for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { unsigned idx = map.getDimPosition(perm(enc, d)); // Handle sparse storage schemes. - if (merger.isDim(t, idx, Dim::kSparse)) { + if (merger.isDim(t->getOperandNumber(), idx, Dim::kSparse)) { auto dynShape = {ShapedType::kDynamicSize}; auto ptrTp = MemRefType::get( dynShape, genIntType(rewriter, enc.getPointerBitWidth())); @@ -586,36 +581,37 @@ dynShape, genIntType(rewriter, enc.getIndexBitWidth())); Value dim = rewriter.create(loc, d); // Generate sparse primitives to obtains pointer and indices. - codegen.pointers[t][idx] = - rewriter.create(loc, ptrTp, tensor, dim); - codegen.indices[t][idx] = - rewriter.create(loc, indTp, tensor, dim); + codegen.pointers[t->getOperandNumber()][idx] = + rewriter.create(loc, ptrTp, t->get(), dim); + codegen.indices[t->getOperandNumber()][idx] = + rewriter.create(loc, indTp, t->get(), dim); } // Find lower and upper bound in current dimension. Value up; if (shape[d] == MemRefType::kDynamicSize) { - up = rewriter.create(loc, tensor, d); + up = rewriter.create(loc, t->get(), d); args.push_back(up); } else { up = rewriter.create(loc, shape[d]); } - codegen.sizes[idx] = codegen.highs[t][idx] = up; + codegen.sizes[idx] = codegen.highs[t->getOperandNumber()][idx] = up; } // Perform the required bufferization. All dense inputs materialize // from the input tensor. The dense output tensor needs special // handling. Sparse inputs use a sparse primitive to obtain the values. if (!enc) { - auto denseTp = MemRefType::get(shape, tensorType.getElementType()); - if (t < numInputs) - codegen.buffers[t] = - rewriter.create(loc, denseTp, tensor); + auto denseTp = MemRefType::get(shape, elementType); + if (t->getOperandNumber() < op.getNumInputs()) + codegen.buffers[t->getOperandNumber()] = + rewriter.create(loc, denseTp, t->get()); else - codegen.buffers[t] = + codegen.buffers[t->getOperandNumber()] = genOutputBuffer(codegen, rewriter, op, denseTp, args); } else { auto dynShape = {ShapedType::kDynamicSize}; - auto sparseTp = MemRefType::get(dynShape, tensorType.getElementType()); - codegen.buffers[t] = rewriter.create(loc, sparseTp, tensor); + auto sparseTp = MemRefType::get(dynShape, elementType); + codegen.buffers[t->getOperandNumber()] = + rewriter.create(loc, sparseTp, t->get()); } } } @@ -709,19 +705,22 @@ } // Actual load. SmallVector args; - unsigned tensor = merger.exp(exp).e0; - auto map = op.getIndexingMap(tensor); - auto enc = getSparseTensorEncoding(op.getShapedType(tensor)); + OpOperand *tensor = merger.exp(exp).e0 < op.getNumInputs() + ? op.getInputOperand(merger.exp(exp).e0) + : op.getOutputOperand(0); + auto map = op.getTiedIndexingMap(tensor); + auto enc = getSparseTensorEncoding(tensor->get().getType()); for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { unsigned idx = map.getDimPosition(perm(enc, d)); args.push_back(codegen.loops[idx]); // universal dense index if (enc) { args.clear(); - args.push_back(codegen.pidxs[tensor][idx]); // position index + args.push_back( + codegen.pidxs[tensor->getOperandNumber()][idx]); // position index } } Location loc = op.getLoc(); - Value ptr = codegen.buffers[tensor]; + Value ptr = codegen.buffers[tensor->getOperandNumber()]; if (codegen.curVecLength > 1) return genVectorLoad(codegen, rewriter, ptr, args); return rewriter.create(loc, ptr, args); @@ -730,10 +729,10 @@ /// Generates a store on a dense tensor. static void genTensorStore(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, linalg::GenericOp op, - unsigned tensor, Value rhs) { + OpOperand *tensor, Value rhs) { Location loc = op.getLoc(); // Test if this is a scalarized reduction. - unsigned lhs = op.getNumShapedOperands() - 1; + OpOperand *lhs = op.getOutputOperand(0); if (lhs == tensor && codegen.redVal) { if (codegen.curVecLength > 1) rhs = rewriter.create(loc, codegen.curVecMask, rhs, @@ -743,13 +742,13 @@ } // Actual store. SmallVector args; - auto map = op.getIndexingMap(tensor); - assert(!getSparseTensorEncoding(op.getShapedType(tensor))); + auto map = op.getTiedIndexingMap(tensor); + assert(!getSparseTensorEncoding(tensor->get().getType())); for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { unsigned idx = map.getDimPosition(d); args.push_back(codegen.loops[idx]); // universal dense index } - Value ptr = codegen.buffers[tensor]; + Value ptr = codegen.buffers[tensor->getOperandNumber()]; if (codegen.curVecLength > 1) genVectorStore(codegen, rewriter, rhs, ptr, args); else @@ -844,7 +843,7 @@ return; assert(codegen.curVecLength == 1); codegen.redVal = merger.exp(codegen.redExp).val = Value(); // end chain - unsigned lhs = op.getNumShapedOperands() - 1; + OpOperand *lhs = op.getOutputOperand(0); if (auto vtp = red.getType().dyn_cast()) { // TODO: assumes + reductions for now StringAttr kind = rewriter.getStringAttr("add"); @@ -894,9 +893,11 @@ if (merger.exp(exp).kind == Kind::kTensor) { // Inspect tensor indices. bool atLevel = ldx == -1u; - unsigned tensor = merger.exp(exp).e0; - auto map = op.getIndexingMap(tensor); - auto enc = getSparseTensorEncoding(op.getShapedType(tensor)); + OpOperand *tensor = merger.exp(exp).e0 < op.getNumInputs() + ? op.getInputOperand(merger.exp(exp).e0) + : op.getOutputOperand(0); + auto map = op.getTiedIndexingMap(tensor); + auto enc = getSparseTensorEncoding(tensor->get().getType()); for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { unsigned idx = map.getDimPosition(perm(enc, d)); if (!codegen.loops[idx]) @@ -905,7 +906,7 @@ atLevel = true; } // All exhausted at this level (atLevel denotes exactly at this level). - unsigned lhs = op.getNumShapedOperands() - 1; + OpOperand *lhs = op.getOutputOperand(0); if (lhs == tensor) { codegen.redExp = hoist ? exp : -1u; } else if (atLevel) { @@ -1006,10 +1007,9 @@ /// TODO: implement strided load/stores on dense arrays static bool denseUnitStrides(Merger &merger, linalg::GenericOp op, unsigned idx) { - unsigned numTensors = op.getNumShapedOperands(); - for (unsigned t = 0; t < numTensors; t++) { - if (!getSparseTensorEncoding(op.getShapedType(t))) { - auto map = op.getIndexingMap(t); + for (OpOperand *t : op.getInputAndOutputOperands()) { + if (!getSparseTensorEncoding(t->get().getType())) { + auto map = op.getTiedIndexingMap(t); for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { if (map.getDimPosition(d) == idx && d != rank - 1) return false; @@ -1271,7 +1271,7 @@ unsigned exp, unsigned at) { // At each leaf, assign remaining tensor (sub)expression to output tensor. if (at == topSort.size()) { - unsigned lhs = op.getNumShapedOperands() - 1; + OpOperand *lhs = op.getOutputOperand(0); Value rhs = genExp(merger, codegen, rewriter, op, exp); genTensorStore(merger, codegen, rewriter, op, lhs, rhs); return; @@ -1371,7 +1371,7 @@ // Detects sparse annotations and translate the per-dimension sparsity // information for all tensors to loop indices in the kernel. assert(op.getNumOutputs() == 1); - unsigned numTensors = op.getNumShapedOperands(); + unsigned numTensors = op.getNumInputsAndOutputs(); unsigned numLoops = op.iterator_types().getValue().size(); Merger merger(numTensors, numLoops); if (!findSparseAnnotations(merger, op))