diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/CodegenUtils.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/CodegenUtils.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/CodegenUtils.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/CodegenUtils.h
@@ -35,6 +35,9 @@
 /// Converts an overhead storage bitwidth to its internal type-encoding.
 OverheadType overheadTypeEncoding(unsigned width);
 
+/// Converts an overhead storage type to its internal type-encoding.
+OverheadType overheadTypeEncoding(Type tp);
+
 /// Converts the internal type-encoding for overhead storage to an mlir::Type.
 Type getOverheadType(Builder &builder, OverheadType ot);
 
@@ -46,9 +49,21 @@
 Type getIndexOverheadType(Builder &builder,
                           const SparseTensorEncodingAttr &enc);
 
+/// Converts an overhead storage type-encoding to its function-name suffix.
+StringRef overheadTypeFunctionSuffix(OverheadType ot);
+
+/// Converts an overhead storage type to its function-name suffix.
+StringRef overheadTypeFunctionSuffix(Type overheadTp);
+
 /// Converts a primary storage type to its internal type-encoding.
 PrimaryType primaryTypeEncoding(Type elemTp);
 
+/// Converts a primary storage type-encoding to its function-name suffix.
+StringRef primaryTypeFunctionSuffix(PrimaryType pt);
+
+/// Converts a primary storage type to its function-name suffix.
+StringRef primaryTypeFunctionSuffix(Type elemTp);
+
 /// Converts the IR's dimension level type to its internal type-encoding.
 DimLevelType dimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt);
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -231,21 +231,7 @@
                           Type eltType, Value ptr, Value val, Value ind,
                           Value perm) {
   Location loc = op->getLoc();
-  StringRef name;
-  if (eltType.isF64())
-    name = "addEltF64";
-  else if (eltType.isF32())
-    name = "addEltF32";
-  else if (eltType.isInteger(64))
-    name = "addEltI64";
-  else if (eltType.isInteger(32))
-    name = "addEltI32";
-  else if (eltType.isInteger(16))
-    name = "addEltI16";
-  else if (eltType.isInteger(8))
-    name = "addEltI8";
-  else
-    llvm_unreachable("Unknown element type");
+  SmallString<9> name{"addElt", primaryTypeFunctionSuffix(eltType)};
   SmallVector<Value, 4> params;
   params.push_back(ptr);
   params.push_back(val);
@@ -264,21 +250,7 @@
                             Value iter, Value ind, Value elemPtr) {
   Location loc = op->getLoc();
   Type elemTp = elemPtr.getType().cast<ShapedType>().getElementType();
-  StringRef name;
-  if (elemTp.isF64())
-    name = "getNextF64";
-  else if (elemTp.isF32())
-    name = "getNextF32";
-  else if (elemTp.isInteger(64))
-    name = "getNextI64";
-  else if (elemTp.isInteger(32))
-    name = "getNextI32";
-  else if (elemTp.isInteger(16))
-    name = "getNextI16";
-  else if (elemTp.isInteger(8))
-    name = "getNextI8";
-  else
-    llvm_unreachable("Unknown element type");
+  SmallString<10> name{"getNext", primaryTypeFunctionSuffix(elemTp)};
   SmallVector<Value, 3> params;
   params.push_back(iter);
   params.push_back(ind);
@@ -649,20 +621,8 @@
   matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type resType = op.getType();
-    Type eltType = resType.cast<ShapedType>().getElementType();
-    StringRef name;
-    if (eltType.isIndex())
-      name = "sparsePointers";
-    else if (eltType.isInteger(64))
-      name = "sparsePointers64";
-    else if (eltType.isInteger(32))
-      name = "sparsePointers32";
-    else if (eltType.isInteger(16))
-      name = "sparsePointers16";
-    else if (eltType.isInteger(8))
-      name = "sparsePointers8";
-    else
-      return failure();
+    Type ptrType = resType.cast<ShapedType>().getElementType();
+    SmallString<16> name{"sparsePointers", overheadTypeFunctionSuffix(ptrType)};
     auto fn = getFunc(op, name, resType, adaptor.getOperands(),
                       /*emitCInterface=*/true);
     rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
@@ -678,20 +638,8 @@
   matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type resType = op.getType();
-    Type eltType = resType.cast<ShapedType>().getElementType();
-    StringRef name;
-    if (eltType.isIndex())
-      name = "sparseIndices";
-    else if (eltType.isInteger(64))
-      name = "sparseIndices64";
-    else if (eltType.isInteger(32))
-      name = "sparseIndices32";
-    else if (eltType.isInteger(16))
-      name = "sparseIndices16";
-    else if (eltType.isInteger(8))
-      name = "sparseIndices8";
-    else
-      return failure();
+    Type indType = resType.cast<ShapedType>().getElementType();
+    SmallString<15> name{"sparseIndices", overheadTypeFunctionSuffix(indType)};
     auto fn = getFunc(op, name, resType, adaptor.getOperands(),
                       /*emitCInterface=*/true);
     rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
@@ -708,21 +656,7 @@
                   ConversionPatternRewriter &rewriter) const override {
     Type resType = op.getType();
     Type eltType = resType.cast<ShapedType>().getElementType();
-    StringRef name;
-    if (eltType.isF64())
-      name = "sparseValuesF64";
-    else if (eltType.isF32())
-      name = "sparseValuesF32";
-    else if (eltType.isInteger(64))
-      name = "sparseValuesI64";
-    else if (eltType.isInteger(32))
-      name = "sparseValuesI32";
-    else if (eltType.isInteger(16))
-      name = "sparseValuesI16";
-    else if (eltType.isInteger(8))
-      name = "sparseValuesI8";
-    else
-      return failure();
+    SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltType)};
     auto fn = getFunc(op, name, resType, adaptor.getOperands(),
                       /*emitCInterface=*/true);
     rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
@@ -758,21 +692,7 @@
                   ConversionPatternRewriter &rewriter) const override {
     Type srcType = op.tensor().getType();
     Type eltType = srcType.cast<ShapedType>().getElementType();
-    StringRef name;
-    if (eltType.isF64())
-      name = "lexInsertF64";
-    else if (eltType.isF32())
-      name = "lexInsertF32";
-    else if (eltType.isInteger(64))
-      name = "lexInsertI64";
-    else if (eltType.isInteger(32))
-      name = "lexInsertI32";
-    else if (eltType.isInteger(16))
-      name = "lexInsertI16";
-    else if (eltType.isInteger(8))
-      name = "lexInsertI8";
-    else
-      llvm_unreachable("Unknown element type");
+    SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(eltType)};
     TypeRange noTp;
     auto fn = getFunc(op, name, noTp, adaptor.getOperands(),
                       /*emitCInterface=*/true);
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Utils/CodegenUtils.cpp
--- a/mlir/lib/Dialect/SparseTensor/Utils/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/CodegenUtils.cpp
@@ -40,6 +40,14 @@
   llvm_unreachable("Unsupported overhead bitwidth");
 }
 
+OverheadType overheadTypeEncoding(Type tp) {
+  if (tp.isIndex())
+    return OverheadType::kIndex;
+  if (auto intTp = tp.dyn_cast<IntegerType>())
+    return overheadTypeEncoding(intTp.getWidth());
+  llvm_unreachable("Unknown overhead type");
+}
+
 Type getOverheadType(Builder &builder, OverheadType ot) {
   switch (ot) {
   case OverheadType::kIndex:
@@ -67,6 +75,26 @@
   return getOverheadType(builder, overheadTypeEncoding(enc.getIndexBitWidth()));
 }
 
+StringRef overheadTypeFunctionSuffix(OverheadType ot) {
+  switch (ot) {
+  case OverheadType::kIndex:
+    return "";
+  case OverheadType::kU64:
+    return "64";
+  case OverheadType::kU32:
+    return "32";
+  case OverheadType::kU16:
+    return "16";
+  case OverheadType::kU8:
+    return "8";
+  }
+  llvm_unreachable("Unknown OverheadType");
+}
+
+StringRef overheadTypeFunctionSuffix(Type overheadTp) {
+  return overheadTypeFunctionSuffix(overheadTypeEncoding(overheadTp));
+}
+
 PrimaryType primaryTypeEncoding(Type elemTp) {
   if (elemTp.isF64())
     return PrimaryType::kF64;
@@ -83,6 +111,28 @@
   llvm_unreachable("Unknown primary type");
 }
 
+StringRef primaryTypeFunctionSuffix(PrimaryType pt) {
+  switch (pt) {
+  case PrimaryType::kF64:
+    return "F64";
+  case PrimaryType::kF32:
+    return "F32";
+  case PrimaryType::kI64:
+    return "I64";
+  case PrimaryType::kI32:
+    return "I32";
+  case PrimaryType::kI16:
+    return "I16";
+  case PrimaryType::kI8:
+    return "I8";
+  }
+  llvm_unreachable("Unknown PrimaryType");
+}
+
+StringRef primaryTypeFunctionSuffix(Type elemTp) {
+  return primaryTypeFunctionSuffix(primaryTypeEncoding(elemTp));
+}
+
 DimLevelType dimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
   switch (dlt) {
   case SparseTensorEncodingAttr::DimLevelType::Dense: