diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -64,6 +64,28 @@
 /// given dimension (0 <= d < rank).
 bool isUniqueDim(RankedTensorType type, uint64_t d);
 
+//
+// Reordering.
+//
+
+/// Translates the given stored dimension to the original dimension,
+/// according to the dimension ordering of the encoding. Returns `d`
+/// unchanged when the encoding is null or has no dimension ordering.
+uint64_t toOrigDim(const SparseTensorEncodingAttr &enc, uint64_t d);
+
+/// Translates the given original dimension to the stored dimension,
+/// according to the dimension ordering of the encoding. Returns `d`
+/// unchanged when the encoding is null or has no dimension ordering.
+uint64_t toStoredDim(const SparseTensorEncodingAttr &enc, uint64_t d);
+
+/// Convenience method to translate the given stored dimension
+/// to the original dimension (0 <= d < rank).
+uint64_t toOrigDim(RankedTensorType type, uint64_t d);
+
+/// Convenience method to translate the given original dimension
+/// to the stored dimension (0 <= d < rank).
+uint64_t toStoredDim(RankedTensorType type, uint64_t d);
+
 } // namespace sparse_tensor
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -319,6 +319,40 @@
   return true; // unannotated tensor is dense (and thus unique)
 }
 
+uint64_t mlir::sparse_tensor::toOrigDim(const SparseTensorEncodingAttr &enc,
+                                        uint64_t d) {
+  if (enc) {
+    auto order = enc.getDimOrdering();
+    if (order) {
+      assert(order.isPermutation());
+      return order.getDimPosition(d);
+    }
+  }
+  return d;
+}
+
+uint64_t mlir::sparse_tensor::toStoredDim(const SparseTensorEncodingAttr &enc,
+                                          uint64_t d) {
+  if (enc) {
+    auto order = enc.getDimOrdering();
+    if (order) {
+      assert(order.isPermutation());
+      return order.getPermutedPosition(d);
+    }
+  }
+  return d;
+}
+
+uint64_t mlir::sparse_tensor::toOrigDim(RankedTensorType type, uint64_t d) {
+  assert(d < static_cast<uint64_t>(type.getRank()));
+  return toOrigDim(getSparseTensorEncoding(type), d);
+}
+
+uint64_t mlir::sparse_tensor::toStoredDim(RankedTensorType type, uint64_t d) {
+  assert(d < static_cast<uint64_t>(type.getRank()));
+  return toStoredDim(getSparseTensorEncoding(type), d);
+}
+
 //===----------------------------------------------------------------------===//
 // TensorDialect Operations.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -35,26 +35,6 @@
 // Helper methods.
 //===----------------------------------------------------------------------===//
 
-/// Reorders stored dimension to original dimension.
-static unsigned toOrig(const SparseTensorEncodingAttr &enc, unsigned i) {
-  auto order = enc.getDimOrdering();
-  if (order) {
-    assert(order.isPermutation());
-    return order.getDimPosition(i);
-  }
-  return i;
-}
-
-/// Reorders original dimension to stored dimension.
-static unsigned toStored(const SparseTensorEncodingAttr &enc, unsigned i) {
-  auto order = enc.getDimOrdering();
-  if (order) {
-    assert(order.isPermutation());
-    return order.getPermutedPosition(i);
-  }
-  return i;
-}
-
 /// Flatten a list of operands that may contain sparse tensors.
 static void flattenOperands(ValueRange operands,
                             SmallVectorImpl<Value> &flattened) {
@@ -79,7 +59,7 @@
 /// Gets the dimension size for the given sparse tensor at the given dim.
 /// Returns None if no sparse encoding is attached to the tensor type.
 static Optional<Value> sizeFromTensorAtDim(OpBuilder &rewriter, Location loc,
-                                           ShapedType tensorTp,
+                                           RankedTensorType tensorTp,
                                            Value adaptedValue, unsigned dim) {
   auto enc = getSparseTensorEncoding(tensorTp);
   if (!enc)
@@ -95,9 +75,8 @@
   // accounting for the reordering applied to the sparse storage.
   auto tuple =
       llvm::cast<UnrealizedConversionCastOp>(adaptedValue.getDefiningOp());
-  return rewriter
-      .create<memref::LoadOp>(loc, tuple.getInputs().front(),
-                              constantIndex(rewriter, loc, toStored(enc, dim)))
-      .getResult();
+  Value idx = constantIndex(rewriter, loc, toStoredDim(tensorTp, dim));
+  return rewriter.create<memref::LoadOp>(loc, tuple.getInputs().front(), idx)
+      .getResult();
 }
 
@@ -243,7 +222,7 @@
   // Per-dimension storage.
   for (unsigned r = 0; r < rank; r++) {
     // Get the original dimension (ro) for the current stored dimension.
-    unsigned ro = toOrig(enc, r);
+    unsigned ro = toOrigDim(rType, r);
     builder.create<memref::StoreOp>(loc, sizes[ro], dimSizes,
                                     constantIndex(builder, loc, r));
     linear = builder.create<arith::MulIOp>(loc, linear, sizes[ro]);
@@ -490,10 +469,7 @@
     // Determine the size for access expansion (always the innermost stored
     // dimension size, translated back to original dimension). Note that we
     // recursively rewrite the new DimOp on the **original** tensor.
-    auto enc = getSparseTensorEncoding(srcType);
-    unsigned innerDim = srcType.getRank() - 1;
-    if (AffineMap p = enc.getDimOrdering())
-      innerDim = p.getDimPosition(innerDim);
+    unsigned innerDim = toOrigDim(srcType, srcType.getRank() - 1);
     auto sz = sizeFromTensorAtDim(rewriter, loc, srcType, adaptor.getTensor(),
                                   innerDim);
     assert(sz); // This for sure is a sparse tensor
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -106,13 +106,11 @@
 /// Generates dimension size call.
 static Value genDimSizeCall(OpBuilder &builder, Location loc,
                             SparseTensorEncodingAttr &enc, Value src,
-                            int64_t idx) {
-  // Permute the index according to an optional dimension ordering.
-  if (AffineMap p = enc.getDimOrdering())
-    idx = p.getPermutedPosition(idx);
+                            uint64_t idx) {
   // Generate the call.
   StringRef name = "sparseDimSize";
-  SmallVector<Value, 2> params{src, constantIndex(builder, loc, idx)};
+  SmallVector<Value, 2> params{
+      src, constantIndex(builder, loc, toStoredDim(enc, idx))};
   Type iTp = builder.getIndexType();
   return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
       .getResult(0);
@@ -266,13 +264,8 @@
   // default, or otherwise the "reverse" permutation of a given ordering, so
   // that indices can be mapped quickly to the right position.
   SmallVector<Value, 4> rev(sz);
-  if (AffineMap p = enc.getDimOrdering()) {
-    for (unsigned i = 0; i < sz; i++)
-      rev[p.getDimPosition(i)] = constantIndex(builder, loc, i);
-  } else {
-    for (unsigned i = 0; i < sz; i++)
-      rev[i] = constantIndex(builder, loc, i);
-  }
+  for (unsigned i = 0; i < sz; i++)
+    rev[toOrigDim(enc, i)] = constantIndex(builder, loc, i);
   params.push_back(genBuffer(builder, loc, rev));
   // Secondary and primary types encoding.
   Type elemTp = stp.getElementType();
@@ -1230,7 +1223,8 @@
   matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
-    ShapedType srcType = op.getTensor().getType().cast<ShapedType>();
+    RankedTensorType srcType =
+        op.getTensor().getType().cast<RankedTensorType>();
     Type eltType = srcType.getElementType();
     Type boolType = rewriter.getIntegerType(1);
     Type idxType = rewriter.getIndexType();
@@ -1239,9 +1233,7 @@
     // Determine the size for access expansion (always the innermost stored
     // dimension size, translated back to original dimension).
     auto enc = getSparseTensorEncoding(srcType);
-    unsigned innerDim = srcType.getRank() - 1;
-    if (AffineMap p = enc.getDimOrdering())
-      innerDim = p.getDimPosition(innerDim);
+    unsigned innerDim = toOrigDim(srcType, srcType.getRank() - 1);
     auto sz = sizeFromPtrAtDim(rewriter, loc, enc, srcType, adaptor.getTensor(),
                                innerDim);
     // Allocate temporary buffers for values, filled-switch, and indices.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -128,18 +128,6 @@
   return AffineMap::getPermutationMap(perm, context);
 }
 
-/// Helper method to apply dimension ordering permutation.
-static unsigned perm(const SparseTensorEncodingAttr &enc, unsigned d) {
-  if (enc) {
-    auto order = enc.getDimOrdering();
-    if (order) {
-      assert(order.isPermutation());
-      return order.getDimPosition(d);
-    }
-  }
-  return d;
-}
-
 /// Helper method to obtain the dimension level format from the encoding.
 //
 // TODO: note that we store, but currently completely *ignore* the properties
@@ -214,7 +202,7 @@
     assert(map.getNumResults() == op.getRank(t));
     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
       unsigned tensor = t->getOperandNumber();
-      AffineExpr a = map.getResult(perm(enc, d));
+      AffineExpr a = map.getResult(toOrigDim(enc, d));
       if (!findAffine(merger, tensor, a, toDimLevelFormat(enc, d)))
         return false; // inadmissible affine expression
     }
@@ -319,8 +307,8 @@
     // example, the tensor expresion A_ijk forces the ordering i < j < k
     // on the loop indices if no explicit dimension ordering is given.
     for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
-      AffineExpr f = map.getResult(perm(enc, d - 1));
-      AffineExpr t = map.getResult(perm(enc, d));
+      AffineExpr f = map.getResult(toOrigDim(enc, d - 1));
+      AffineExpr t = map.getResult(toOrigDim(enc, d));
       addAffineOrderings(adjM, inDegree, f, t, 0);
     }
     // Push unrelated loops into sparse iteration space, so these
@@ -359,7 +347,7 @@
 /// whether the out tensor in the tensor expression codegen is admissible.
 /// Sets `sparseOut` to the tensor and `outerParNest` to the outer injective
 /// nesting depth when a "truly dynamic" sparse tensor output occurs.
-static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
+static bool isAdmissibleTensorExp(Merger &merger, linalg::GenericOp op,
                                   std::vector<unsigned> &topSort, unsigned exp,
                                   OpOperand **sparseOut,
                                   unsigned &outerParNest) {
@@ -367,7 +355,7 @@
   unsigned tensor = lhs->getOperandNumber();
   auto enc = getSparseTensorEncoding(lhs->get().getType());
-  // An non-annotated output tensor is assumed dense, and becomes a random
-  // access n-dim memref. Admissable since insertions cannot occur.
+  // A non-annotated output tensor is assumed dense, and becomes a random
+  // access n-dim memref. Admissible since insertions cannot occur.
   if (!enc)
     return true;
   // An all-dense annotated "sparse" output tensor becomes a linearized random
@@ -559,7 +547,7 @@
     // Scan all dimensions of current tensor.
     args.clear();
     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
-      AffineExpr a = map.getResult(perm(enc, d));
+      AffineExpr a = map.getResult(toOrigDim(enc, d));
       if (a.getKind() != AffineExprKind::DimId)
         continue; // compound
       unsigned idx = a.cast<AffineDimExpr>().getPosition();
@@ -587,7 +575,7 @@
         assert(merger.isDimLevelType(tensor, idx, DimLvlType::kDense));
       }
       // Find upper bound in current dimension.
-      unsigned p = perm(enc, d);
+      unsigned p = toOrigDim(enc, d);
       Value up = linalg::createOrFoldDimOp(builder, loc, t->get(), p);
       if (ShapedType::isDynamic(shape[p]))
         args.push_back(up);
@@ -735,7 +723,7 @@
 static Value genIndex(CodeGen &codegen, linalg::GenericOp op, OpOperand *t) {
   auto map = op.getTiedIndexingMap(t);
   auto enc = getSparseTensorEncoding(t->get().getType());
-  AffineExpr a = map.getResult(perm(enc, map.getNumResults() - 1));
+  AffineExpr a = map.getResult(toOrigDim(enc, map.getNumResults() - 1));
   assert(a.getKind() == AffineExprKind::DimId);
   unsigned idx = a.cast<AffineDimExpr>().getPosition();
   return codegen.loops[idx];
@@ -752,14 +740,15 @@
   if (enc) {
     // Note that currently, all sparse subscripts are simple.
     // TODO: accept affine too?
-    AffineExpr a = map.getResult(perm(enc, rank - 1));
+    AffineExpr a = map.getResult(toOrigDim(enc, rank - 1));
     assert(a.getKind() == AffineExprKind::DimId);
     unsigned idx = a.cast<AffineDimExpr>().getPosition();
     assert(codegen.pidxs[tensor][idx] != nullptr);
     args.push_back(codegen.pidxs[tensor][idx]); // position index
   } else {
     for (unsigned d = 0; d < rank; d++) {
-      AffineExpr a = map.getResult(perm(enc, d));
+      // No encoding here, hence no dimension ordering: identity mapping.
+      AffineExpr a = map.getResult(d);
       args.push_back(genAffine(codegen, builder, a, op.getLoc()));
     }
   }
@@ -1094,7 +1083,7 @@
     auto map = op.getTiedIndexingMap(t);
     auto enc = getSparseTensorEncoding(t->get().getType());
     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
-      AffineExpr a = map.getResult(perm(enc, d));
+      AffineExpr a = map.getResult(toOrigDim(enc, d));
       if (!isInvariantAffine(codegen, a, ldx, atLevel))
         return; // still in play
     }
@@ -1882,7 +1871,7 @@
     for (auto mask : allMask)
       if (computeIterationGraph(merger, op, topSort, mask)) {
         hasCycle = false;
-        if (isAdmissableTensorExp(merger, op, topSort, exp, &sparseOut,
+        if (isAdmissibleTensorExp(merger, op, topSort, exp, &sparseOut,
                                   outerParNest)) {
           isAdmissible = true;
           break;