diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -32,6 +32,9 @@
 /// Converts an overhead storage bitwidth to its internal type-encoding.
 OverheadType overheadTypeEncoding(unsigned width);
 
+/// Converts an overhead storage type to its internal type-encoding.
+OverheadType overheadTypeEncoding(Type tp);
+
 /// Converts the internal type-encoding for overhead storage to an mlir::Type.
 Type getOverheadType(Builder &builder, OverheadType ot);
 
@@ -43,9 +46,21 @@
 Type getIndexOverheadType(Builder &builder,
                           const SparseTensorEncodingAttr &enc);
 
+/// Converts an OverheadType to its function-name suffix.
+StringRef overheadTypeFunctionSuffix(OverheadType ot);
+
+/// Converts an overhead storage type to its function-name suffix.
+StringRef overheadTypeFunctionSuffix(Type tp);
+
 /// Converts a primary storage type to its internal type-encoding.
 PrimaryType primaryTypeEncoding(Type elemTp);
 
+/// Converts a PrimaryType to its function-name suffix.
+StringRef primaryTypeFunctionSuffix(PrimaryType pt);
+
+/// Converts a primary storage type to its function-name suffix.
+StringRef primaryTypeFunctionSuffix(Type elemTp);
+
 /// Converts the IR's dimension level type to its internal type-encoding.
 DimLevelType dimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt);
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -34,6 +34,14 @@
   llvm_unreachable("Unsupported overhead bitwidth");
 }
 
+OverheadType mlir::sparse_tensor::overheadTypeEncoding(Type tp) {
+  if (tp.isIndex())
+    return OverheadType::kIndex;
+  if (auto intTp = tp.dyn_cast<IntegerType>())
+    return overheadTypeEncoding(intTp.getWidth());
+  llvm_unreachable("Unknown overhead type");
+}
+
 Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) {
   switch (ot) {
   case OverheadType::kIndex:
@@ -61,6 +69,26 @@
   return getOverheadType(builder, overheadTypeEncoding(enc.getIndexBitWidth()));
 }
 
+StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(OverheadType ot) {
+  switch (ot) {
+  case OverheadType::kIndex:
+    return "";
+  case OverheadType::kU64:
+    return "64";
+  case OverheadType::kU32:
+    return "32";
+  case OverheadType::kU16:
+    return "16";
+  case OverheadType::kU8:
+    return "8";
+  }
+  llvm_unreachable("Unknown OverheadType");
+}
+
+StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(Type tp) {
+  return overheadTypeFunctionSuffix(overheadTypeEncoding(tp));
+}
+
 PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) {
   if (elemTp.isF64())
     return PrimaryType::kF64;
@@ -77,6 +105,28 @@
   llvm_unreachable("Unknown primary type");
 }
 
+StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(PrimaryType pt) {
+  switch (pt) {
+  case PrimaryType::kF64:
+    return "F64";
+  case PrimaryType::kF32:
+    return "F32";
+  case PrimaryType::kI64:
+    return "I64";
+  case PrimaryType::kI32:
+    return "I32";
+  case PrimaryType::kI16:
+    return "I16";
+  case PrimaryType::kI8:
+    return "I8";
+  }
+  llvm_unreachable("Unknown PrimaryType");
+}
+
+StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(Type elemTp) {
+  return primaryTypeFunctionSuffix(primaryTypeEncoding(elemTp));
+}
+
 DimLevelType mlir::sparse_tensor::dimLevelTypeEncoding(
     SparseTensorEncodingAttr::DimLevelType dlt) {
   switch (dlt) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -260,21 +260,7 @@
 static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
                           Type eltType, Value ptr, Value val, Value ind,
                           Value perm) {
-  StringRef name;
-  if (eltType.isF64())
-    name = "addEltF64";
-  else if (eltType.isF32())
-    name = "addEltF32";
-  else if (eltType.isInteger(64))
-    name = "addEltI64";
-  else if (eltType.isInteger(32))
-    name = "addEltI32";
-  else if (eltType.isInteger(16))
-    name = "addEltI16";
-  else if (eltType.isInteger(8))
-    name = "addEltI8";
-  else
-    llvm_unreachable("Unknown element type");
+  SmallString<9> name{"addElt", primaryTypeFunctionSuffix(eltType)};
   SmallVector<Value, 4> params{ptr, val, ind, perm};
   Type pTp = getOpaquePointerType(rewriter);
   createFuncCall(rewriter, op, name, pTp, params, EmitCInterface::On);
@@ -287,21 +273,7 @@
 static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op,
                             Value iter, Value ind, Value elemPtr) {
   Type elemTp = elemPtr.getType().cast<ShapedType>().getElementType();
-  StringRef name;
-  if (elemTp.isF64())
-    name = "getNextF64";
-  else if (elemTp.isF32())
-    name = "getNextF32";
-  else if (elemTp.isInteger(64))
-    name = "getNextI64";
-  else if (elemTp.isInteger(32))
-    name = "getNextI32";
-  else if (elemTp.isInteger(16))
-    name = "getNextI16";
-  else if (elemTp.isInteger(8))
-    name = "getNextI8";
-  else
-    llvm_unreachable("Unknown element type");
+  SmallString<10> name{"getNext", primaryTypeFunctionSuffix(elemTp)};
   SmallVector<Value, 3> params{iter, ind, elemPtr};
   Type i1 = rewriter.getI1Type();
   return createFuncCall(rewriter, op, name, i1, params, EmitCInterface::On)
@@ -668,20 +640,8 @@
   matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type resType = op.getType();
-    Type eltType = resType.cast<ShapedType>().getElementType();
-    StringRef name;
-    if (eltType.isIndex())
-      name = "sparsePointers";
-    else if (eltType.isInteger(64))
-      name = "sparsePointers64";
-    else if (eltType.isInteger(32))
-      name = "sparsePointers32";
-    else if (eltType.isInteger(16))
-      name = "sparsePointers16";
-    else if (eltType.isInteger(8))
-      name = "sparsePointers8";
-    else
-      return failure();
+    Type ptrType = resType.cast<ShapedType>().getElementType();
+    SmallString<16> name{"sparsePointers", overheadTypeFunctionSuffix(ptrType)};
     replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                           EmitCInterface::On);
     return success();
@@ -696,20 +656,8 @@
   matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type resType = op.getType();
-    Type eltType = resType.cast<ShapedType>().getElementType();
-    StringRef name;
-    if (eltType.isIndex())
-      name = "sparseIndices";
-    else if (eltType.isInteger(64))
-      name = "sparseIndices64";
-    else if (eltType.isInteger(32))
-      name = "sparseIndices32";
-    else if (eltType.isInteger(16))
-      name = "sparseIndices16";
-    else if (eltType.isInteger(8))
-      name = "sparseIndices8";
-    else
-      return failure();
+    Type indType = resType.cast<ShapedType>().getElementType();
+    SmallString<15> name{"sparseIndices", overheadTypeFunctionSuffix(indType)};
     replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                           EmitCInterface::On);
     return success();
@@ -725,21 +673,7 @@
                   ConversionPatternRewriter &rewriter) const override {
     Type resType = op.getType();
     Type eltType = resType.cast<ShapedType>().getElementType();
-    StringRef name;
-    if (eltType.isF64())
-      name = "sparseValuesF64";
-    else if (eltType.isF32())
-      name = "sparseValuesF32";
-    else if (eltType.isInteger(64))
-      name = "sparseValuesI64";
-    else if (eltType.isInteger(32))
-      name = "sparseValuesI32";
-    else if (eltType.isInteger(16))
-      name = "sparseValuesI16";
-    else if (eltType.isInteger(8))
-      name = "sparseValuesI8";
-    else
-      return failure();
+    SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltType)};
     replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                           EmitCInterface::On);
     return success();
@@ -772,23 +706,8 @@
   LogicalResult
   matchAndRewrite(LexInsertOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type srcType = op.tensor().getType();
-    Type eltType = srcType.cast<ShapedType>().getElementType();
-    StringRef name;
-    if (eltType.isF64())
-      name = "lexInsertF64";
-    else if (eltType.isF32())
-      name = "lexInsertF32";
-    else if (eltType.isInteger(64))
-      name = "lexInsertI64";
-    else if (eltType.isInteger(32))
-      name = "lexInsertI32";
-    else if (eltType.isInteger(16))
-      name = "lexInsertI16";
-    else if (eltType.isInteger(8))
-      name = "lexInsertI8";
-    else
-      llvm_unreachable("Unknown element type");
+    Type elemTp = op.tensor().getType().cast<ShapedType>().getElementType();
+    SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
     TypeRange noTp;
     replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                           EmitCInterface::On);
@@ -843,23 +762,8 @@
     // all-zero/false by only iterating over the set elements, so the
     // complexity remains proportional to the sparsity of the expanded
     // access pattern.
-    Type srcType = op.tensor().getType();
-    Type eltType = srcType.cast<ShapedType>().getElementType();
-    StringRef name;
-    if (eltType.isF64())
-      name = "expInsertF64";
-    else if (eltType.isF32())
-      name = "expInsertF32";
-    else if (eltType.isInteger(64))
-      name = "expInsertI64";
-    else if (eltType.isInteger(32))
-      name = "expInsertI32";
-    else if (eltType.isInteger(16))
-      name = "expInsertI16";
-    else if (eltType.isInteger(8))
-      name = "expInsertI8";
-    else
-      return failure();
+    Type elemTp = op.tensor().getType().cast<ShapedType>().getElementType();
+    SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
     TypeRange noTp;
     replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                           EmitCInterface::On);