diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -21,6 +21,7 @@
 #include "mlir/IR/Builders.h"
 
 namespace mlir {
+
 class Location;
 class Type;
 class Value;
@@ -151,7 +152,7 @@
 //===----------------------------------------------------------------------===//
 // ExecutionEngine/SparseTensorUtils helper functions.
 //===----------------------------------------------------------------------===//
-//
+
 /// Converts an overhead storage bitwidth to its internal type-encoding.
 OverheadType overheadTypeEncoding(unsigned width);
 
@@ -194,10 +195,7 @@
 DimLevelType dimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt);
 
 //===----------------------------------------------------------------------===//
-// Misc code generators.
-//
-// TODO: both of these should move upstream to their respective classes.
-// Once RFCs have been created for those changes, list them here.
+// Misc code generators and utilities.
 //===----------------------------------------------------------------------===//
 
 /// Generates a 1-valued attribute of the given type.  This supports
@@ -211,8 +209,24 @@
 /// true if `v` is NaN).
 Value genIsNonzero(OpBuilder &builder, Location loc, Value v);
 
+/// Computes the shape of the destination tensor of a reshape operator. This is
+/// only used when operands have dynamic shape. The shape of the destination is
+/// stored into `dstShape`.
+void genReshapeDstShape(Location loc, PatternRewriter &rewriter,
+                        SmallVector<Value, 4> &dstShape,
+                        ArrayRef<Value> srcShape,
+                        ArrayRef<int64_t> staticDstShape,
+                        ArrayRef<ReassociationIndices> reassociation);
+
+/// Translates indices during a reshaping operation.
+void translateIndicesArray(OpBuilder &builder, Location loc,
+                           ArrayRef<ReassociationIndices> reassociation,
+                           ValueRange srcIndices, ArrayRef<Value> srcShape,
+                           ArrayRef<Value> dstShape,
+                           SmallVectorImpl<Value> &dstIndices);
+
 //===----------------------------------------------------------------------===//
-// Constant generators.
+// Inlined constant generators.
 //
 // All these functions are just wrappers to improve code legibility;
 // therefore, we mark them as `inline` to avoid introducing any additional
@@ -315,21 +329,6 @@
                     static_cast<uint8_t>(dimLevelTypeEncoding(dlt)));
 }
 
-/// Computes the shape of destination tensor of a reshape operator. This is only
-/// used when operands have dynamic shape. The shape of the destination is
-/// stored into dstShape.
-void genReshapeDstShape(Location loc, PatternRewriter &rewriter,
-                        SmallVector<Value, 4> &dstShape,
-                        ArrayRef<Value> srcShape,
-                        ArrayRef<int64_t> staticDstShape,
-                        ArrayRef<ReassociationIndices> reassociation);
-
-/// Helper method to translate indices during a reshaping operation.
-void translateIndicesArray(OpBuilder &builder, Location loc,
-                           ArrayRef<ReassociationIndices> reassociation,
-                           ValueRange srcIndices, ArrayRef<Value> srcShape,
-                           ArrayRef<Value> dstShape,
-                           SmallVectorImpl<Value> &dstIndices);
 } // namespace sparse_tensor
 } // namespace mlir
 