diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -78,6 +78,13 @@
 // Misc code generators and utilities.
 //===----------------------------------------------------------------------===//
 
+/// Gets the `RankedTensorType` of the given value `t`; shorthand for
+/// `t.getType().cast<RankedTensorType>()`.
+template <typename T>
+inline RankedTensorType getRankedTensorType(T t) {
+  return t.getType().template cast<RankedTensorType>();
+}
+
 /// Generates a 1-valued attribute of the given type.  This supports
 /// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`,
 /// for unsupported types we raise `llvm_unreachable` rather than
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -565,7 +565,7 @@
 
 Value sparse_tensor::genToPointers(OpBuilder &builder, Location loc,
                                    Value tensor, uint64_t d) {
-  RankedTensorType srcTp = tensor.getType().cast<RankedTensorType>();
+  RankedTensorType srcTp = getRankedTensorType(tensor);
   SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp);
   Type ptrTp = get1DMemRefType(getPointerOverheadType(builder, encSrc),
                                /*withLayout=*/false);
@@ -575,7 +575,7 @@
 
 Value sparse_tensor::genToIndices(OpBuilder &builder, Location loc,
                                   Value tensor, uint64_t d, uint64_t cooStart) {
-  RankedTensorType srcTp = tensor.getType().cast<RankedTensorType>();
+  RankedTensorType srcTp = getRankedTensorType(tensor);
   SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp);
   Type indTp = get1DMemRefType(getIndexOverheadType(builder, encSrc),
                                /*withLayout=*/d >= cooStart);
@@ -585,7 +585,7 @@
 
 Value sparse_tensor::genToValues(OpBuilder &builder, Location loc,
                                  Value tensor) {
-  RankedTensorType srcTp = tensor.getType().cast<RankedTensorType>();
+  RankedTensorType srcTp = getRankedTensorType(tensor);
   Type valTp = get1DMemRefType(srcTp.getElementType(),
                                /*withLayout=*/false);
   return builder.create<ToValuesOp>(loc, valTp, tensor);
@@ -596,4 +596,4 @@
   SmallVector<Value> fields;
   auto desc = getMutDescriptorFromTensorTuple(tensor, fields);
   return desc.getValMemSize(builder, loc);
-}
\ No newline at end of file
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -82,7 +82,7 @@
     // a scalar or 0-dimension tensors
     if (isZeroRankedTensorOrScalar(t.getType()))
      continue;
-    auto rtp = t.getType().cast<RankedTensorType>();
+    auto rtp = getRankedTensorType(t);
    auto rank = static_cast<size_t>(rtp.getRank());
    auto enc = getSparseTensorEncoding(rtp);
    // We always treat sparse output tensor as dense so that we always iterate
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -768,8 +768,7 @@
       return failure();
     Location loc = op->getLoc();
     auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
-    RankedTensorType srcType =
-        op.getTensor().getType().cast<RankedTensorType>();
+    auto srcType = getRankedTensorType(op.getTensor());
     Type eltType = srcType.getElementType();
     Type boolType = rewriter.getIntegerType(1);
     Type idxType = rewriter.getIndexType();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -268,7 +268,7 @@
         !isAlloc(op.getDpsInitOperand(0), /*isZero=*/false) ||
         !isZeroYield(op) || !op.getDpsInitOperand(0)->get().hasOneUse())
       return failure();
-    auto outputType = op.getResult(0).getType().cast<RankedTensorType>();
+    auto outputType = getRankedTensorType(op.getResult(0));
     // Yielding zero on newly allocated (all-zero) sparse tensors can be
     // optimized out directly (regardless of dynamic or static size).
     if (getSparseTensorEncoding(outputType)) {
@@ -405,8 +405,8 @@
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Value srcTensor = op.getSrc();
-    auto srcTp = srcTensor.getType().template cast<RankedTensorType>();
-    auto dstTp = op.getResult().getType().template cast<RankedTensorType>();
+    auto srcTp = getRankedTensorType(srcTensor);
+    auto dstTp = getRankedTensorType(op.getResult());
     SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp);
     SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
     if (!encDst || !encSrc) {
@@ -483,8 +483,7 @@
       return failure();
     }
     if (encSrc) {
-      RankedTensorType rtp =
-          op.getSrc().getType().template cast<RankedTensorType>();
+      auto rtp = getRankedTensorType(op.getSrc());
       auto denseTp =
           RankedTensorType::get(rtp.getShape(), rtp.getElementType());
       auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
@@ -492,8 +491,7 @@
       return success();
     }
     if (encDst) {
-      RankedTensorType rtp =
-          op.getResult().getType().template cast<RankedTensorType>();
+      auto rtp = getRankedTensorType(op.getResult());
       auto denseTp =
           RankedTensorType::get(rtp.getShape(), rtp.getElementType());
       auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
@@ -511,7 +509,7 @@
   LogicalResult matchAndRewrite(ConcatenateOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    auto dstTp = op.getType().cast<RankedTensorType>();
+    auto dstTp = getRankedTensorType(op);
     uint64_t conDim = op.getDimension().getZExtValue();
     SmallVector<Value> sizes;
     concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(), conDim);
@@ -547,7 +545,7 @@
     // CSC matrices along column).
     if (!allDense && conDim == 0 && encDst.hasIdDimOrdering()) {
       for (auto i : op.getInputs()) {
-        auto rtp = i.getType().cast<RankedTensorType>();
+        auto rtp = getRankedTensorType(i);
         auto srcEnc = getSparseTensorEncoding(rtp);
         if (isAllDimOrdered(rtp) && (!srcEnc || srcEnc.hasIdDimOrdering())) {
           allOrdered = true;
@@ -623,7 +621,7 @@
       // Accumulates the offset. Note that only static-shaped inputs are allowed
       // by concatenate op verifier, which saves us from computing the offset
       // dynamically.
-      int64_t d = input.getType().cast<RankedTensorType>().getShape()[conDim];
+      int64_t d = getRankedTensorType(input).getShape()[conDim];
       assert(!ShapedType::isDynamic(d));
       offset = rewriter.create<arith::AddIOp>(loc, offset,
                                               constantIndex(rewriter, loc, d));
@@ -699,7 +697,7 @@
                                   PatternRewriter &rewriter) const {
     Location loc = op.getLoc();
     Value src = op.getSource();
-    RankedTensorType dstTp = op.getType().cast<RankedTensorType>();
+    auto dstTp = getRankedTensorType(op);
     SmallVector<Value> sizes;
     sizesFromSrc(rewriter, sizes, loc, src);
     SmallVector<Value> dynSizes;
@@ -769,9 +767,9 @@
   LogicalResult sparse2DenseRewrite(ConvertOp op,
                                     PatternRewriter &rewriter) const {
     Location loc = op->getLoc();
-    RankedTensorType dstTp = op.getType().cast<RankedTensorType>();
+    RankedTensorType dstTp = getRankedTensorType(op);
     Value src = op.getSource();
-    RankedTensorType srcTp = src.getType().cast<RankedTensorType>();
+    RankedTensorType srcTp = getRankedTensorType(src);
     SmallVector<Value> sizes;
     sizesForTensor(rewriter, sizes, loc, srcTp, src);
 
@@ -808,8 +806,8 @@
                                   PatternRewriter &rewriter) const {
     Location loc = op->getLoc();
     Value src = op.getSource();
-    RankedTensorType srcTp = src.getType().cast<RankedTensorType>();
-    RankedTensorType dstTp = op.getType().cast<RankedTensorType>();
+    RankedTensorType srcTp = getRankedTensorType(src);
+    RankedTensorType dstTp = getRankedTensorType(op);
     SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
     int64_t rank = dstTp.getRank();
 
@@ -928,7 +926,7 @@
     auto loc = op.getLoc();
     Value input = op.getTensor();
     SmallVector<Value> reduc = op.getInitArgs();
-    auto rtp = input.getType().cast<RankedTensorType>();
+    auto rtp = getRankedTensorType(input);
     int64_t rank = rtp.getRank();
 
     // Special-case: for each over a sparse constant uses its own rewriting
@@ -1015,7 +1013,7 @@
   LogicalResult matchAndRewrite(NewOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    auto dstTp = op.getResult().getType().template cast<RankedTensorType>();
+    auto dstTp = getRankedTensorType(op.getResult());
     SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
     if (!encDst)
       return failure();
@@ -1138,7 +1136,7 @@
     Value nnz = rewriter.create<NumberOfEntriesOp>(loc, src);
 
     // Allocate a temporary buffer for storing dimension sizes and indices.
-    auto srcTp = src.getType().template cast<RankedTensorType>();
+    auto srcTp = getRankedTensorType(src);
     uint64_t rank = srcTp.getRank();
     Type indexTp = rewriter.getIndexType();
     Value dimSizes = genAlloca(rewriter, loc, rank, indexTp);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1589,7 +1589,7 @@
       // TODO: investigate fusing the conversion with computation,
      // especially if it is a direct yield!
      //
-      auto srcTp = tval.getType().cast<RankedTensorType>();
+      auto srcTp = getRankedTensorType(tval);
      auto dstEnc = SparseTensorEncodingAttr::get(
          getContext(), srcEnc.getDimLevelType(),
          permute(env, env.op().getMatchingIndexingMap(t)), // new order
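
Note for reviewers: the new helper relies on the `template` disambiguator, which is easy to misread. The following minimal, self-contained C++ sketch (the `Fake*` types are hypothetical stand-ins for illustration, not MLIR's actual classes) shows why the body must be spelled `t.getType().template cast<RankedTensorType>()`: inside a function template the type of `t.getType()` depends on `T`, so without the `template` keyword the compiler would parse `cast <` as a less-than comparison.

#include <cassert>

// Hypothetical stand-in for mlir::RankedTensorType.
struct FakeRankedTensorType {
  int rank;
};

// Hypothetical stand-in for mlir::Type: exposes a `cast<U>()` member
// template, mirroring the casting API the patch calls through.
struct FakeType {
  int rank;
  template <typename U>
  U cast() const { return U{rank}; }
};

// Hypothetical stand-in for anything with a `getType()` accessor
// (an mlir::Value, an op result, or a single-result op).
struct FakeValue {
  FakeType type;
  FakeType getType() const { return type; }
};

// Mirrors the helper added to CodegenUtils.h: since `t.getType()` has a
// type dependent on the template parameter `T`, the member template `cast`
// must be prefixed with `template`.
template <typename T>
inline FakeRankedTensorType getRankedTensorType(T t) {
  return t.getType().template cast<FakeRankedTensorType>();
}

int main() {
  FakeValue v{FakeType{2}};
  FakeRankedTensorType rtp = getRankedTensorType(v);
  assert(rtp.rank == 2); // one call site instead of a repeated cast chain
  return 0;
}

The call-site rewrite is then mechanical: every `x.getType().cast<RankedTensorType>()` (and its `.template` variant inside class templates) collapses to `getRankedTensorType(x)`. Passing an op directly, as in `getRankedTensorType(op)`, works for single-result ops because they expose `getType()` for their one result.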