diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -38,6 +38,18 @@
 namespace mlir {
 namespace sparse_tensor {
 
+/// Convenience method to abbreviate casting `getType()`.
+template <typename T>
+inline RankedTensorType getRankedTensorType(T t) {
+  return t.getType().template cast<RankedTensorType>();
+}
+
+/// Convenience method to abbreviate casting `getType()`.
+template <typename T>
+inline MemRefType getMemRefType(T t) {
+  return t.getType().template cast<MemRefType>();
+}
+
 /// Convenience method to get a sparse encoding attribute from a type.
 /// Returns null-attribute for any type without an encoding.
 SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -28,6 +28,15 @@
 using namespace mlir;
 using namespace mlir::sparse_tensor;
 
+//===----------------------------------------------------------------------===//
+// Additional convenience methods.
+//===----------------------------------------------------------------------===//
+
+template <typename T>
+static inline int64_t getTypeRank(T t) {
+  return getRankedTensorType(t).getRank();
+}
+
 //===----------------------------------------------------------------------===//
 // TensorDialect Attribute Methods.
 //===----------------------------------------------------------------------===//
@@ -525,12 +534,11 @@
 //===----------------------------------------------------------------------===//
 
 static LogicalResult isInBounds(uint64_t dim, Value tensor) {
-  return success(dim <
-                 (uint64_t)tensor.getType().cast<RankedTensorType>().getRank());
+  return success(dim < static_cast<uint64_t>(getTypeRank(tensor)));
 }
 
 static LogicalResult isMatchingWidth(Value result, unsigned width) {
-  const Type etp = result.getType().cast<MemRefType>().getElementType();
+  const Type etp = getMemRefType(result).getElementType();
   return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
 }
@@ -562,8 +570,7 @@
 }
 
 LogicalResult NewOp::verify() {
-  if (getExpandSymmetry() &&
-      getResult().getType().cast<RankedTensorType>().getRank() != 2)
+  if (getExpandSymmetry() && getTypeRank(getResult()) != 2)
     return emitOpError("expand_symmetry can only be used for 2D tensors");
   return success();
 }
@@ -624,8 +631,8 @@
 }
 
 LogicalResult ToValuesOp::verify() {
-  RankedTensorType ttp = getTensor().getType().cast<RankedTensorType>();
-  MemRefType mtp = getResult().getType().cast<MemRefType>();
+  auto ttp = getRankedTensorType(getTensor());
+  auto mtp = getMemRefType(getResult());
   if (ttp.getElementType() != mtp.getElementType())
     return emitError("unexpected mismatch in element types");
   return success();
 }
@@ -754,7 +761,7 @@
 }
 
 LogicalResult ConcatenateOp::verify() {
-  auto dstTp = getType().cast<RankedTensorType>();
+  auto dstTp = getRankedTensorType(*this);
   uint64_t concatDim = getDimension().getZExtValue();
   unsigned rank = dstTp.getRank();
@@ -775,8 +782,7 @@
                       concatDim));
 
   for (size_t i = 0, e = getInputs().size(); i < e; i++) {
-    Value input = getInputs()[i];
-    auto inputRank = input.getType().cast<RankedTensorType>().getRank();
+    const auto inputRank = getTypeRank(getInputs()[i]);
     if (inputRank != rank)
       return emitError(
           llvm::formatv("The input tensor ${0} has a different rank (rank={1}) "
@@ -785,15 +791,13 @@
   }
 
   for (unsigned i = 0; i < rank; i++) {
-    auto dstDim = dstTp.getShape()[i];
+    const auto dstDim = dstTp.getShape()[i];
     if (i == concatDim) {
       if (!ShapedType::isDynamic(dstDim)) {
+        // If we reach here, all inputs should have static shapes.
         unsigned sumDim = 0;
-        for (auto src : getInputs()) {
-          // If we reach here, all inputs should have static shapes.
-          auto d = src.getType().cast<RankedTensorType>().getShape()[i];
-          sumDim += d;
-        }
+        for (auto src : getInputs())
+          sumDim += getRankedTensorType(src).getShape()[i];
         // If all dimension are statically known, the sum of all the input
         // dimensions should be equal to the output dimension.
         if (sumDim != dstDim)
@@ -804,7 +808,7 @@
     } else {
       int64_t prev = dstDim;
       for (auto src : getInputs()) {
-        auto d = src.getType().cast<RankedTensorType>().getShape()[i];
+        const auto d = getRankedTensorType(src).getShape()[i];
         if (!ShapedType::isDynamic(prev) && d != prev)
           return emitError("All dimensions (expect for the concatenating one) "
                            "should be equal.");
@@ -817,8 +821,7 @@
 }
 
 LogicalResult InsertOp::verify() {
-  RankedTensorType ttp = getTensor().getType().cast<RankedTensorType>();
-  if (ttp.getRank() != static_cast<int64_t>(getIndices().size()))
+  if (getTypeRank(getTensor()) != static_cast<int64_t>(getIndices().size()))
     return emitOpError("incorrect number of indices");
   return success();
 }
@@ -838,8 +841,7 @@
 }
 
 LogicalResult CompressOp::verify() {
-  RankedTensorType ttp = getTensor().getType().cast<RankedTensorType>();
-  if (ttp.getRank() != 1 + static_cast<int64_t>(getIndices().size()))
+  if (getTypeRank(getTensor()) != 1 + static_cast<int64_t>(getIndices().size()))
     return emitOpError("incorrect number of indices");
   return success();
 }
@@ -860,7 +862,7 @@
   // Builds foreach body.
   if (!bodyBuilder)
     return;
-  auto rtp = tensor.getType().cast<RankedTensorType>();
+  auto rtp = getRankedTensorType(tensor);
   int64_t rank = rtp.getRank();
   SmallVector<Type> blockArgTypes;
@@ -886,7 +888,7 @@
 }
 
 LogicalResult ForeachOp::verify() {
-  auto t = getTensor().getType().cast<RankedTensorType>();
+  auto t = getRankedTensorType(getTensor());
   auto args = getBody()->getArguments();
 
   if (static_cast<size_t>(t.getRank()) + 1 + getInitArgs().size() !=
@@ -944,11 +946,11 @@
   auto n = getN().getDefiningOp<arith::ConstantIndexOp>();
-  Type xtp = getXs().front().getType().cast<MemRefType>().getElementType();
+  Type xtp = getMemRefType(getXs().front()).getElementType();
   auto checkTypes = [&](ValueRange operands,
                         bool checkEleType = true) -> LogicalResult {
     for (Value opnd : operands) {
-      MemRefType mtp = opnd.getType().cast<MemRefType>();
+      auto mtp = getMemRefType(opnd);
       int64_t dim = mtp.getShape()[0];
       // We can't check the size of dynamic dimension at compile-time, but all
       // xs and ys should have a dimension not less than n at runtime.
@@ -986,7 +988,7 @@
   }
 
   auto checkDim = [&](Value v, uint64_t min, const char *message) {
-    MemRefType tp = v.getType().cast<MemRefType>();
+    auto tp = getMemRefType(v);
     int64_t dim = tp.getShape()[0];
     if (!ShapedType::isDynamic(dim) && dim < (int64_t)min) {
       emitError(llvm::formatv("{0} got {1} < {2}", message, dim, min));
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -78,11 +78,6 @@
 // Misc code generators and utilities.
 //===----------------------------------------------------------------------===//
 
-template <typename T>
-inline RankedTensorType getRankedTensorType(T t) {
-  return t.getType().template cast<RankedTensorType>();
-}
-
 /// Generates a 1-valued attribute of the given type.  This supports
 /// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`,
 /// for unsupported types we raise `llvm_unreachable` rather than
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -558,7 +558,7 @@
   idxBuffer = builder.create<memref::CastOp>(
       loc, MemRefType::get({rank}, builder.getIndexType()), idxBuffer);
   SmallVector<int64_t> shape(rank, ShapedType::kDynamic);
-  Type elemTp = valuesBuffer.getType().cast<MemRefType>().getElementType();
+  Type elemTp = getMemRefType(valuesBuffer).getElementType();
   return builder.create<memref::ReshapeOp>(loc, MemRefType::get(shape, elemTp),
                                            valuesBuffer, idxBuffer);
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -55,14 +55,14 @@
                                 ValueRange operands) {
   nameOstream << namePrefix << nx << "_"
-              << operands[xStartIdx].getType().cast<MemRefType>().getElementType();
+              << getMemRefType(operands[xStartIdx]).getElementType();
 
   if (isCoo)
     nameOstream << "_coo_" << ny;
 
   uint64_t yBufferOffset = isCoo ? 1 : nx;
   for (Value v : operands.drop_front(xStartIdx + yBufferOffset))
-    nameOstream << "_" << v.getType().cast<MemRefType>().getElementType();
+    nameOstream << "_" << getMemRefType(v).getElementType();
 }
 
 /// Looks up a function that is appropriate for the given operands being
@@ -719,7 +719,7 @@
     // Convert `values` to have dynamic shape and append them to `operands`.
     for (Value v : xys) {
-      auto mtp = v.getType().cast<MemRefType>();
+      auto mtp = getMemRefType(v);
       if (!mtp.isDynamicDim(0)) {
         auto newMtp =
             MemRefType::get({ShapedType::kDynamic}, mtp.getElementType());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -505,8 +505,8 @@
 genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
                        ConversionPatternRewriter &rewriter) {
   Location loc = op.getLoc();
-  auto srcTp = op.getSrc().getType().template cast<RankedTensorType>();
-  auto dstTp = op.getResult().getType().template cast<RankedTensorType>();
+  auto srcTp = getRankedTensorType(op.getSrc());
+  auto dstTp = getRankedTensorType(op.getResult());
   auto encSrc = getSparseTensorEncoding(srcTp);
   auto encDst = getSparseTensorEncoding(dstTp);
   if (!encDst || !encSrc)
@@ -888,8 +888,8 @@
   matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
-    Type resType = op.getType();
-    Type srcType = op.getSource().getType();
+    auto resType = getRankedTensorType(op);
+    auto srcType = getRankedTensorType(op.getSource());
     auto encDst = getSparseTensorEncoding(resType);
     auto encSrc = getSparseTensorEncoding(srcType);
     Value src = adaptor.getOperands()[0];
@@ -953,10 +953,8 @@
     //     dst[elem.indices] = elem.value;
     //   }
     //   delete iter;
-    RankedTensorType dstTensorTp = resType.cast<RankedTensorType>();
-    RankedTensorType srcTensorTp = srcType.cast<RankedTensorType>();
-    unsigned rank = dstTensorTp.getRank();
-    Type elemTp = dstTensorTp.getElementType();
+    const unsigned rank = resType.getRank();
+    const Type elemTp = resType.getElementType();
     // Fabricate a no-permutation encoding for NewCallParams
     // The pointer/index types must be those of `src`.
     // The dimLevelTypes aren't actually used by Action::kToIterator.
@@ -965,16 +963,16 @@
         SmallVector<DimLevelType>(rank, DimLevelType::Dense), AffineMap(),
         AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
     SmallVector<Value> dimSizes =
-        getDimSizes(rewriter, loc, encSrc, srcTensorTp, src);
+        getDimSizes(rewriter, loc, encSrc, srcType, src);
     Value iter = NewCallParams(rewriter, loc)
-                     .genBuffers(encDst, dimSizes, dstTensorTp)
+                     .genBuffers(encDst, dimSizes, resType)
                      .genNewCall(Action::kToIterator, src);
     Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
     Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
     Block *insertionBlock = rewriter.getInsertionBlock();
     // TODO: Dense buffers should be allocated/deallocated via the callback
     // in BufferizationOptions.
-    Value dst = allocDenseTensor(rewriter, loc, dstTensorTp, dimSizes);
+    Value dst = allocDenseTensor(rewriter, loc, resType, dimSizes);
     SmallVector<Value> noArgs;
     SmallVector<Type> noTypes;
     auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
@@ -1192,7 +1190,7 @@
     // index order. All values are passed by reference through stack
     // allocated memrefs.
     Location loc = op->getLoc();
-    auto tp = op.getTensor().getType().cast<RankedTensorType>();
+    auto tp = getRankedTensorType(op.getTensor());
     auto elemTp = tp.getElementType();
     unsigned rank = tp.getRank();
     auto mref = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
@@ -1217,8 +1215,7 @@
   matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
-    RankedTensorType srcType =
-        op.getTensor().getType().cast<RankedTensorType>();
+    auto srcType = getRankedTensorType(op.getTensor());
     Type eltType = srcType.getElementType();
     Type boolType = rewriter.getIntegerType(1);
     Type idxType = rewriter.getIndexType();
@@ -1272,7 +1269,7 @@
     Value added = adaptor.getAdded();
     Value count = adaptor.getCount();
     Value tensor = adaptor.getTensor();
-    auto tp = op.getTensor().getType().cast<RankedTensorType>();
+    auto tp = getRankedTensorType(op.getTensor());
     Type elemTp = tp.getElementType();
     unsigned rank = tp.getRank();
     auto mref = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
@@ -1326,7 +1323,7 @@
     //    a[ adjustForOffset(elem.indices) ] = elem.value
     //    return a
     Location loc = op.getLoc();
-    auto dstTp = op.getType().cast<RankedTensorType>();
+    auto dstTp = getRankedTensorType(op);
     auto encDst = getSparseTensorEncoding(dstTp);
     Type elemTp = dstTp.getElementType();
     uint64_t concatDim = op.getDimension().getZExtValue();
@@ -1381,7 +1378,7 @@
     for (auto it : llvm::zip(op.getInputs(), adaptor.getInputs())) {
       Value orignalOp = std::get<0>(it); // Input (with encoding) from Op
       Value adaptedOp = std::get<1>(it); // Input (type converted) from adaptor
-      RankedTensorType srcTp = orignalOp.getType().cast<RankedTensorType>();
+      auto srcTp = getRankedTensorType(orignalOp);
       auto encSrc = getSparseTensorEncoding(srcTp);
       if (encSrc) {
         genSparseCOOIterationLoop(
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -69,7 +69,7 @@
 
 /// Constructs vector type from pointer.
 static VectorType vectorType(VL vl, Value ptr) {
-  return vectorType(vl, ptr.getType().cast<MemRefType>().getElementType());
+  return vectorType(vl, getMemRefType(ptr).getElementType());
 }
 
 /// Constructs vector iteration mask.