diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -177,6 +177,12 @@
   return dlt == DimLevelType::Dense;
 }
 
+/// Strips the property bits (the two low-order bits) from the `DimLevelType`,
+/// keeping only the level format.
+constexpr DimLevelType stripLevelProperty(DimLevelType dlt) {
+  return static_cast<DimLevelType>(static_cast<uint8_t>(dlt) & ~3);
+}
+
 // We use the idiom `(dlt & ~3) == format` in order to only return true
 // for valid DLTs. Whereas the `dlt & format` idiom is a bit faster but
 // can return false-positives on invalid DLTs.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
@@ -54,7 +54,7 @@
   let builders = [
     TypeBuilderWithInferredContext<(ins "SparseTensorEncodingAttr":$encoding), [{
       assert(encoding && "sparse tensor encoding should not be null");
-      return $_get(encoding.getContext(), encoding);
+      return get(encoding.getContext(), encoding);
     }]>,
     TypeBuilderWithInferredContext<(ins "Type":$type), [{
       return get(getSparseTensorEncoding(type));
@@ -69,8 +69,12 @@
     IntegerType getSizesType() const;
     Type getFieldType(StorageSpecifierKind kind, Optional<unsigned> dim) const;
     Type getFieldType(StorageSpecifierKind kind, Optional<APInt> dim) const;
+    /// Builds the specifier type from a canonicalized copy of `enc` (level
+    /// property bits stripped, identity dimension ordering dropped).
+    static StorageSpecifierType get(MLIRContext *ctx, SparseTensorEncodingAttr enc);
   }];
-  
+
+  let skipDefaultBuilders = 1;
   let assemblyFormat="`<` qualified($encoding) `>`";
 }
 
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -186,6 +186,21 @@
                                                     bool enableVLAVectorization,
                                                     bool enableSIMDIndex32);
 
+/// Type converter that lowers `StorageSpecifierType` to an LLVM struct type
+/// and leaves every other type unchanged.
+class SparseSpecifierToLLVMTypeConverter : public TypeConverter {
+public:
+  SparseSpecifierToLLVMTypeConverter();
+};
+
+/// Populates the patterns that lower the storage specifier operations to
+/// operations on LLVM structs.
+void populateSparseSpecifierToLLVMPatterns(TypeConverter &converter,
+                                           RewritePatternSet &patterns);
+
+/// Creates the pass that lowers sparse storage specifiers to LLVM structs.
+std::unique_ptr<Pass> createSparseSpecifierToLLVMPass();
+
 //===----------------------------------------------------------------------===//
 // Registration.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -301,4 +301,20 @@
   ];
 }
 
+def SparseSpecifierToLLVM : Pass<"sparse-specifier-to-llvm", "ModuleOp"> {
+  let summary = "Lower sparse storage specifiers to LLVM structures";
+  let description = [{
+    This pass rewrites each sparse tensor storage specifier into an LLVM
+    struct holding an array of dimension sizes and an array of memory sizes,
+    and lowers the operations that create, read, and update specifiers into
+    operations on that struct.
+  }];
+  let constructor = "mlir::createSparseSpecifierToLLVMPass()";
+  let dependentDialects = [
+    "arith::ArithDialect",
+    "LLVM::LLVMDialect",
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -341,6 +341,32 @@
 // SparseTensorDialect Types.
 //===----------------------------------------------------------------------===//
 
+/// Returns the canonical form of `enc` used by storage specifier types: the
+/// property bits are stripped from every dimension level type and an identity
+/// dimension ordering is replaced by the null map, so that encodings that
+/// differ only in those respects yield the same specifier type.
+static SparseTensorEncodingAttr
+getCanonicalizedEncoding(SparseTensorEncodingAttr enc) {
+  SmallVector<DimLevelType> dlts;
+  for (auto dlt : enc.getDimLevelType())
+    dlts.push_back(stripLevelProperty(dlt));
+
+  AffineMap dimOrder =
+      enc.getDimOrdering() && !enc.getDimOrdering().isIdentity()
+          ? enc.getDimOrdering()
+          : AffineMap();
+
+  return SparseTensorEncodingAttr::get(
+      enc.getContext(), dlts, dimOrder, enc.getHigherOrdering(),
+      enc.getPointerBitWidth(), enc.getIndexBitWidth());
+}
+
+/// Builds the specifier type from the canonicalized form of `enc`.
+StorageSpecifierType StorageSpecifierType::get(MLIRContext *ctx,
+                                               SparseTensorEncodingAttr enc) {
+  return Base::get(ctx, getCanonicalizedEncoding(enc));
+}
+
 IntegerType StorageSpecifierType::getSizesType() const {
   unsigned idxBitWidth =
       getEncoding().getIndexBitWidth() ? getEncoding().getIndexBitWidth() : 64u;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -2,6 +2,8 @@
   BufferizableOpInterfaceImpl.cpp
   CodegenUtils.cpp
   SparseBufferRewriting.cpp
+  SparseSpecifierToLLVM.cpp
+  SparseTensorBuilder.cpp
   SparseTensorCodegen.cpp
   SparseTensorConversion.cpp
   SparseTensorPasses.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpecifierToLLVM.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpecifierToLLVM.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpecifierToLLVM.cpp
@@ -0,0 +1,211 @@
+//===- SparseSpecifierToLLVM.cpp - Lower storage specifiers to LLVM -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "CodegenUtils.h"
+#include "SparseTensorBuilder.h"
+
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+
+using namespace mlir;
+using namespace sparse_tensor;
+
+/// Returns the field types of the LLVM struct that represents a specifier of
+/// type `tp`: an array with one size per dimension, followed by an array with
+/// one size per data memref. Both arrays use the specifier's sizes type.
+static SmallVector<Type, 2> getSpecifierFields(StorageSpecifierType tp) {
+  MLIRContext *ctx = tp.getContext();
+  auto enc = tp.getEncoding();
+  unsigned rank = enc.getDimLevelType().size();
+
+  SmallVector<Type, 2> result;
+  auto indexType = tp.getSizesType();
+  auto dimSizes = LLVM::LLVMArrayType::get(ctx, indexType, rank);
+  // Size the array with the same layout helpers used for indexing below, so
+  // that every index returned by builder::StorageLayout is in bounds.
+  auto memSizes = LLVM::LLVMArrayType::get(
+      ctx, indexType, builder::getNumDataFieldsFromEncoding(enc));
+  result.push_back(dimSizes);
+  result.push_back(memSizes);
+  return result;
+}
+
+/// Converts a storage specifier type into its LLVM struct representation.
+static Type convertSpecifier(StorageSpecifierType tp) {
+  return LLVM::LLVMStructType::getLiteral(tp.getContext(),
+                                          getSpecifierFields(tp));
+}
+
+SparseSpecifierToLLVMTypeConverter::SparseSpecifierToLLVMTypeConverter() {
+  // Conversions are tried in reverse order of registration: specifiers become
+  // LLVM structs and every other type is kept as is.
+  addConversion([](Type type) { return type; });
+  addConversion([](StorageSpecifierType tp) { return convertSpecifier(tp); });
+}
+
+// Positions of the dimension-size and memory-size arrays in the struct.
+constexpr uint64_t kDimSizePosInSpecifier = 0;
+constexpr uint64_t kMemSizePosInSpecifier = 1;
+
+namespace {
+
+/// Builds IR that reads and updates a lowered specifier, i.e. an LLVM struct
+/// holding a dimension-size array and a memory-size array.
+class SpecifierStructBuilder : public StructBuilder {
+public:
+  explicit SpecifierStructBuilder(Value specifier) : StructBuilder(specifier) {
+    assert(value);
+  }
+
+  // Undef value for dimension sizes, all zero value for memory sizes.
+  static Value getInitValue(OpBuilder &builder, Location loc, Type structType);
+
+  Value dimSize(OpBuilder &builder, Location loc, unsigned dim);
+  void setDimSize(OpBuilder &builder, Location loc, unsigned dim, Value size);
+
+  Value memSize(OpBuilder &builder, Location loc, unsigned pos);
+  void setMemSize(OpBuilder &builder, Location loc, unsigned pos, Value size);
+};
+
+} // namespace
+
+Value SpecifierStructBuilder::getInitValue(OpBuilder &builder, Location loc,
+                                           Type structType) {
+  Value metaData = builder.create<LLVM::UndefOp>(loc, structType);
+  SpecifierStructBuilder md(metaData);
+  auto memSizeArrayType = structType.cast<LLVM::LLVMStructType>()
+                              .getBody()[kMemSizePosInSpecifier]
+                              .cast<LLVM::LLVMArrayType>();
+
+  Value zero = constantZero(builder, loc, memSizeArrayType.getElementType());
+  // Fill memSizes array with zero; dimension sizes stay undefined.
+  for (unsigned i = 0, e = memSizeArrayType.getNumElements(); i < e; i++)
+    md.setMemSize(builder, loc, i, zero);
+
+  return md;
+}
+
+/// Builds IR extracting the size of dimension `dim` from the specifier.
+Value SpecifierStructBuilder::dimSize(OpBuilder &builder, Location loc,
+                                      unsigned dim) {
+  return builder.create<LLVM::ExtractValueOp>(
+      loc, value, ArrayRef<int64_t>({kDimSizePosInSpecifier, dim}));
+}
+
+/// Builds IR inserting `size` as the size of dimension `dim`.
+void SpecifierStructBuilder::setDimSize(OpBuilder &builder, Location loc,
+                                        unsigned dim, Value size) {
+  value = builder.create<LLVM::InsertValueOp>(
+      loc, value, size, ArrayRef<int64_t>({kDimSizePosInSpecifier, dim}));
+}
+
+/// Builds IR extracting the pos-th memory size from the specifier.
+Value SpecifierStructBuilder::memSize(OpBuilder &builder, Location loc,
+                                      unsigned pos) {
+  return builder.create<LLVM::ExtractValueOp>(
+      loc, value, ArrayRef<int64_t>({kMemSizePosInSpecifier, pos}));
+}
+
+/// Builds IR inserting `size` as the pos-th memory size.
+void SpecifierStructBuilder::setMemSize(OpBuilder &builder, Location loc,
+                                        unsigned pos, Value size) {
+  value = builder.create<LLVM::InsertValueOp>(
+      loc, value, size, ArrayRef<int64_t>({kMemSizePosInSpecifier, pos}));
+}
+
+namespace {
+
+/// Shared lowering of the specifier getter/setter ops. `Base` provides static
+/// `onDimSize`/`onMemSize` hooks that build the IR for each kind of field.
+template <typename Base, typename SourceOp>
+class SpecifierGetterSetterOpConverter : public OpConversionPattern<SourceOp> {
+public:
+  using OpAdaptor = typename SourceOp::Adaptor;
+  using OpConversionPattern<SourceOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SpecifierStructBuilder spec(adaptor.getSpecifier());
+    Value v;
+    if (op.getSpecifierKind() == StorageSpecifierKind::DimSize) {
+      v = Base::onDimSize(rewriter, op, spec,
+                          op.getDim().value().getZExtValue());
+    } else {
+      auto enc = op.getSpecifier().getType().getEncoding();
+      builder::StorageLayout layout(enc);
+      Optional<unsigned> dim = std::nullopt;
+      if (op.getDim())
+        dim = op.getDim().value().getZExtValue();
+      // Data memrefs are numbered from zero and precede the specifier field,
+      // so a memref's field index is also its slot in the memSizes array.
+      unsigned idx = layout.getMemRefFieldIndex(op.getSpecifierKind(), dim);
+      v = Base::onMemSize(rewriter, op, spec, idx);
+    }
+
+    rewriter.replaceOp(op, v);
+    return success();
+  }
+};
+
+/// Lowers `SetStorageSpecifierOp` to an insertion into the specifier struct.
+struct SpecifierSetOpConverter
+    : public SpecifierGetterSetterOpConverter<SpecifierSetOpConverter,
+                                              SetStorageSpecifierOp> {
+  using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
+  static Value onDimSize(OpBuilder &builder, SetStorageSpecifierOp op,
+                         SpecifierStructBuilder &spec, unsigned d) {
+    spec.setDimSize(builder, op.getLoc(), d, op.getValue());
+    return spec;
+  }
+
+  static Value onMemSize(OpBuilder &builder, SetStorageSpecifierOp op,
+                         SpecifierStructBuilder &spec, unsigned i) {
+    spec.setMemSize(builder, op.getLoc(), i, op.getValue());
+    return spec;
+  }
+};
+
+/// Lowers `GetStorageSpecifierOp` to an extraction from the specifier struct.
+struct SpecifierGetOpConverter
+    : public SpecifierGetterSetterOpConverter<SpecifierGetOpConverter,
+                                              GetStorageSpecifierOp> {
+  using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
+  static Value onDimSize(OpBuilder &builder, GetStorageSpecifierOp op,
+                         SpecifierStructBuilder &spec, unsigned d) {
+    return spec.dimSize(builder, op.getLoc(), d);
+  }
+  static Value onMemSize(OpBuilder &builder, GetStorageSpecifierOp op,
+                         SpecifierStructBuilder &spec, unsigned i) {
+    return spec.memSize(builder, op.getLoc(), i);
+  }
+};
+
+/// Lowers `StorageSpecifierInitOp` to a struct whose dimension sizes are undef
+/// and whose memory sizes are zero.
+struct SpecifierInitOpConverter
+    : public OpConversionPattern<StorageSpecifierInitOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(StorageSpecifierInitOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type llvmType = getTypeConverter()->convertType(op.getResult().getType());
+    if (!llvmType)
+      return failure();
+    rewriter.replaceOp(op, SpecifierStructBuilder::getInitValue(
+                               rewriter, op.getLoc(), llvmType));
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::populateSparseSpecifierToLLVMPatterns(TypeConverter &converter,
+                                                 RewritePatternSet &patterns) {
+  patterns.add<SpecifierGetOpConverter, SpecifierSetOpConverter,
+               SpecifierInitOpConverter>(converter, patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorBuilder.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorBuilder.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorBuilder.h
@@ -0,0 +1,365 @@
+//===- SparseTensorBuilder.h ------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines utilities for lowering and accessing sparse tensor
+// types.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORBUILDER_H_
+#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORBUILDER_H_
+
+#include "mlir/Conversion/LLVMCommon/StructBuilder.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace sparse_tensor {
+
+// FIXME: this is a tmp namespace
+namespace builder {
+//===----------------------------------------------------------------------===//
+// SparseTensorDescriptor and helpers, manage the sparse tensor memory layout
+// scheme.
+//
+// Sparse tensor storage scheme for rank-dimensional tensor is organized
+// as a single compound type with the following fields. Note that every
+// memref with ? size actually behaves as a "vector", i.e. the stored
+// size is the capacity and the used size resides in the memSizes array.
+//
+// struct {
+//   ; per-dimension d:
+//   ;  if dense:
+//        <nothing>
+//   ;  if compressed:
+//        memref<? x ptr>  pointers-d  ; pointers for sparse dim d
+//        memref<? x idx>  indices-d   ; indices for sparse dim d
+//   ;  if singleton:
+//        memref<? x idx>  indices-d   ; indices for singleton dim d
+//   memref<? x eltType> values        ; values
+//
+//   ; sparse tensor metadata
+//   struct {
+//     array<rank x int> dimSizes  ; sizes for each dimension
+//     array<n x int> memSizes     ; sizes for each data memref
+//   }
+// };
+//
+//===----------------------------------------------------------------------===//
+enum class SparseTensorFieldKind : uint32_t {
+  StorageSpec = 0,
+  PtrMemRef = 1,
+  IdxMemRef = 2,
+  ValMemRef = 3
+};
+
+static_assert(static_cast<uint32_t>(SparseTensorFieldKind::PtrMemRef) ==
+              static_cast<uint32_t>(StorageSpecifierKind::PtrMemSize));
+static_assert(static_cast<uint32_t>(SparseTensorFieldKind::IdxMemRef) ==
+              static_cast<uint32_t>(StorageSpecifierKind::IdxMemSize));
+static_assert(static_cast<uint32_t>(SparseTensorFieldKind::ValMemRef) ==
+              static_cast<uint32_t>(StorageSpecifierKind::ValMemSize));
+
+/// For each field that will be allocated for the given sparse tensor encoding,
+/// calls the callback with the corresponding field index, field kind, dimension
+/// (for sparse tensor level memrefs) and dimlevelType.
+/// The field index always starts with zero and increments by one between two
+/// callback invocations. Iteration stops as soon as the callback returns false.
+/// Ideally, all other methods should rely on this function to query the fields
+/// of a sparse tensor instead of relying on ad-hoc index computation.
+void foreachFieldInSparseTensor(
+    SparseTensorEncodingAttr,
+    llvm::function_ref<bool(unsigned /*fieldIdx*/,
+                            SparseTensorFieldKind /*fieldKind*/,
+                            unsigned /*dim (if applicable)*/,
+                            DimLevelType /*DLT (if applicable)*/)>);
+
+/// Same as above, except that it also builds the Type for the corresponding
+/// field.
+void foreachFieldAndTypeInSparseTensor(
+    RankedTensorType,
+    llvm::function_ref<bool(Type /*fieldType*/, unsigned /*fieldIdx*/,
+                            SparseTensorFieldKind /*fieldKind*/,
+                            unsigned /*dim (if applicable)*/,
+                            DimLevelType /*DLT (if applicable)*/)>);
+
+/// Gets the total number of fields for the given sparse tensor encoding.
+unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc);
+
+/// Gets the total number of data fields (index arrays, pointer arrays, and a
+/// value array) for the given sparse tensor encoding.
+unsigned getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc);
+
+/// Maps a memref field kind to the specifier kind that stores its size.
+inline StorageSpecifierKind toSpecifierKind(SparseTensorFieldKind kind) {
+  assert(kind != SparseTensorFieldKind::StorageSpec);
+  return static_cast<StorageSpecifierKind>(kind);
+}
+
+/// Maps a memory-size specifier kind to the memref field kind it describes.
+inline SparseTensorFieldKind toFieldKind(StorageSpecifierKind kind) {
+  assert(kind != StorageSpecifierKind::DimSize);
+  return static_cast<SparseTensorFieldKind>(kind);
+}
+
+/// Computes field indices within the storage layout of a given encoding.
+class StorageLayout {
+public:
+  explicit StorageLayout(SparseTensorEncodingAttr enc) : enc(enc) {}
+
+  ///
+  /// Getters: get the field index for required field.
+  ///
+  unsigned getMemRefFieldIndex(SparseTensorFieldKind kind,
+                               Optional<unsigned> dim) const;
+
+  unsigned getMemRefFieldIndex(StorageSpecifierKind kind,
+                               Optional<unsigned> dim) const;
+
+private:
+  SparseTensorEncodingAttr enc;
+};
+
+/// Wrapper around a storage specifier value that builds the operations that
+/// create, read, and update it.
+class SparseTensorSpecifier {
+public:
+  explicit SparseTensorSpecifier(Value specifier) : specifier(specifier) {}
+
+  // Undef value for dimension sizes, all zero value for memory sizes.
+  static Value getInitValue(OpBuilder &builder, Location loc,
+                            RankedTensorType rtp);
+
+  /*implicit*/ operator Value() { return specifier; }
+
+  Value getSpecifierField(OpBuilder &builder, Location loc,
+                          StorageSpecifierKind kind, Optional<unsigned> dim);
+
+  void setSpecifierField(OpBuilder &builder, Location loc, Value v,
+                         StorageSpecifierKind kind, Optional<unsigned> dim);
+
+  Type getFieldType(StorageSpecifierKind kind, Optional<unsigned> dim) {
+    return specifier.getType().getFieldType(kind, dim);
+  }
+
+private:
+  TypedValue<StorageSpecifierType> specifier;
+};
+
+/// A helper class around an array of values that corresponds to a sparse
+/// tensor. It provides a set of meaningful APIs to query and update a
+/// particular field in a consistent way. Users should not make assumptions on
+/// how a sparse tensor is laid out but instead rely on this class to access
+/// the right value for the right field.
+template <bool mut>
+class SparseTensorDescriptorImpl {
+private:
+  // Uses ValueRange for immutable descriptors; uses SmallVectorImpl<Value> &
+  // for mutable descriptors.
+  // Using SmallVector for mutable descriptor allows users to reuse it as a tmp
+  // buffers to append value for some special cases, though users should be
+  // responsible to restore the buffer to legal states after their use. It is
+  // probably not a clean way, but it is the most efficient way to avoid copying
+  // the fields into another SmallVector. If a more clear way is wanted, we
+  // should change it to MutableArrayRef instead.
+  using ValueArrayRef = typename std::conditional<mut, SmallVectorImpl<Value> &,
+                                                  ValueRange>::type;
+
+public:
+  SparseTensorDescriptorImpl(Type tp, ValueArrayRef fields)
+      : rType(tp.cast<RankedTensorType>()), fields(fields) {
+    assert(getSparseTensorEncoding(tp) &&
+           builder::getNumFieldsFromEncoding(getSparseTensorEncoding(tp)) ==
+               fields.size());
+    // We should make sure the class is trivially copyable (and should be small
+    // enough) such that we can pass it by value.
+    static_assert(
+        std::is_trivially_copyable_v<SparseTensorDescriptorImpl<mut>>);
+  }
+
+  // Implicit (and cheap) type conversion from MutSparseTensorDescriptor to
+  // SparseTensorDescriptor.
+  template <typename T = SparseTensorDescriptorImpl<true>>
+  /*implicit*/ SparseTensorDescriptorImpl(std::enable_if_t<!mut, T> &mDesc)
+      : rType(mDesc.getTensorType()), fields(mDesc.getFields()) {}
+
+  unsigned getMemRefFieldIndex(SparseTensorFieldKind kind,
+                               Optional<unsigned> dim) const {
+    // Delegates to storage layout.
+    StorageLayout layout(getSparseTensorEncoding(rType));
+    return layout.getMemRefFieldIndex(kind, dim);
+  }
+
+  unsigned getPtrMemRefIndex(unsigned ptrDim) const {
+    return getMemRefFieldIndex(SparseTensorFieldKind::PtrMemRef, ptrDim);
+  }
+
+  unsigned getIdxMemRefIndex(unsigned idxDim) const {
+    return getMemRefFieldIndex(SparseTensorFieldKind::IdxMemRef, idxDim);
+  }
+
+  unsigned getValMemRefIndex() const {
+    return getMemRefFieldIndex(SparseTensorFieldKind::ValMemRef, std::nullopt);
+  }
+
+  unsigned getNumFields() const { return fields.size(); }
+
+  ///
+  /// Getters: get the value for required field.
+  ///
+
+  Value getSpecifierField(OpBuilder &builder, Location loc,
+                          StorageSpecifierKind kind,
+                          Optional<unsigned> dim) const {
+    SparseTensorSpecifier md(fields.back());
+    return md.getSpecifierField(builder, loc, kind, dim);
+  }
+
+  Value getDimSize(OpBuilder &builder, Location loc, unsigned dim) const {
+    return getSpecifierField(builder, loc, StorageSpecifierKind::DimSize, dim);
+  }
+
+  Value getPtrMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
+    return getSpecifierField(builder, loc, StorageSpecifierKind::PtrMemSize,
+                             dim);
+  }
+
+  Value getIdxMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
+    return getSpecifierField(builder, loc, StorageSpecifierKind::IdxMemSize,
+                             dim);
+  }
+
+  Value getValMemSize(OpBuilder &builder, Location loc) const {
+    return getSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize,
+                             std::nullopt);
+  }
+
+  Value getPtrMemRef(unsigned ptrDim) const {
+    return getMemRefField(SparseTensorFieldKind::PtrMemRef, ptrDim);
+  }
+
+  Value getIdxMemRef(unsigned idxDim) const {
+    return getMemRefField(SparseTensorFieldKind::IdxMemRef, idxDim);
+  }
+
+  Value getValMemRef() const {
+    return getMemRefField(SparseTensorFieldKind::ValMemRef, std::nullopt);
+  }
+
+  Value getMemRefField(SparseTensorFieldKind kind,
+                       Optional<unsigned> dim) const {
+    return fields[getMemRefFieldIndex(kind, dim)];
+  }
+
+  Value getMemRefField(unsigned fidx) const {
+    assert(fidx < fields.size() - 1);
+    return fields[fidx];
+  }
+
+  Value getField(unsigned fidx) const {
+    assert(fidx < fields.size());
+    return fields[fidx];
+  }
+
+  ///
+  /// Setters: update the value for required field (only enabled for
+  /// MutSparseTensorDescriptor).
+  ///
+
+  template <typename T = Value>
+  void setMemRefField(SparseTensorFieldKind kind, Optional<unsigned> dim,
+                      std::enable_if_t<mut, T> v) {
+    fields[getMemRefFieldIndex(kind, dim)] = v;
+  }
+
+  template <typename T = Value>
+  void setMemRefField(unsigned fidx, std::enable_if_t<mut, T> v) {
+    assert(fidx < fields.size() - 1);
+    fields[fidx] = v;
+  }
+
+  template <typename T = Value>
+  void setField(unsigned fidx, std::enable_if_t<mut, T> v) {
+    assert(fidx < fields.size());
+    fields[fidx] = v;
+  }
+
+  template <typename T = Value>
+  void setSpecifierField(OpBuilder &builder, Location loc,
+                         StorageSpecifierKind kind, Optional<unsigned> dim,
+                         std::enable_if_t<mut, T> v) {
+    SparseTensorSpecifier md(fields.back());
+    md.setSpecifierField(builder, loc, v, kind, dim);
+    fields.back() = md;
+  }
+
+  template <typename T = Value>
+  void setDimSize(OpBuilder &builder, Location loc, unsigned dim,
+                  std::enable_if_t<mut, T> v) {
+    setSpecifierField(builder, loc, StorageSpecifierKind::DimSize, dim, v);
+  }
+
+  ValueRange getMemRefFields() const {
+    ValueRange ret = fields;
+    // drop the last metadata fields
+    return ret.slice(0, fields.size() - 1);
+  }
+
+  Type getMemRefElementType(SparseTensorFieldKind kind,
+                            Optional<unsigned> dim) const {
+    return getMemRefField(kind, dim)
+        .getType()
+        .template cast<MemRefType>()
+        .getElementType();
+  }
+
+  RankedTensorType getTensorType() const { return rType; }
+  ValueArrayRef getFields() const { return fields; }
+
+private:
+  RankedTensorType rType;
+  ValueArrayRef fields;
+};
+
+using SparseTensorDescriptor = SparseTensorDescriptorImpl<false>;
+using MutSparseTensorDescriptor = SparseTensorDescriptorImpl<true>;
+
+/// Returns the "tuple" value of the adapted tensor.
+inline UnrealizedConversionCastOp getTuple(Value tensor) {
+  return llvm::cast<UnrealizedConversionCastOp>(tensor.getDefiningOp());
+}
+
+/// Packs the given values as a "tuple" value.
+inline Value genTuple(OpBuilder &builder, Location loc, Type tp,
+                      ValueRange values) {
+  return builder.create<UnrealizedConversionCastOp>(loc, TypeRange(tp), values)
+      .getResult(0);
+}
+
+inline Value genTuple(OpBuilder &builder, Location loc,
+                      SparseTensorDescriptor desc) {
+  return genTuple(builder, loc, desc.getTensorType(), desc.getFields());
+}
+
+inline SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) {
+  auto tuple = getTuple(tensor);
+  return SparseTensorDescriptor(tuple.getResultTypes()[0], tuple.getInputs());
+}
+
+inline MutSparseTensorDescriptor
+getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields) {
+  auto tuple = getTuple(tensor);
+  fields.assign(tuple.getInputs().begin(), tuple.getInputs().end());
+  return MutSparseTensorDescriptor(tuple.getResultTypes()[0], fields);
+}
+
+} // namespace builder
+} // namespace sparse_tensor
+} // namespace mlir
+#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORBUILDER_H_
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorBuilder.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorBuilder.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorBuilder.cpp
@@ -0,0 +1,195 @@
+//===- SparseTensorBuilder.cpp --------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "SparseTensorBuilder.h"
+#include "CodegenUtils.h"
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace sparse_tensor;
+
+/// Casts `value` to type `to` with an `arith.index_cast`, unless it already
+/// has that type.
+static Value createIndexCast(OpBuilder &builder, Location loc, Value value,
+                             Type to) {
+  if (value.getType() != to)
+    return builder.create<arith::IndexCastOp>(loc, to, value);
+  return value;
+}
+
+/// Wraps an optional dimension into an index attribute; null when absent.
+static IntegerAttr fromOptionalInt(MLIRContext *ctx, Optional<unsigned> dim) {
+  if (!dim)
+    return nullptr;
+  return IntegerAttr::get(IndexType::get(ctx), dim.value());
+}
+
+unsigned
+builder::StorageLayout::getMemRefFieldIndex(SparseTensorFieldKind kind,
+                                            Optional<unsigned> dim) const {
+  unsigned fieldIdx = -1u;
+  foreachFieldInSparseTensor(
+      enc,
+      [dim, kind, &fieldIdx](unsigned fIdx, SparseTensorFieldKind fKind,
+                             unsigned fDim, DimLevelType dlt) -> bool {
+        if ((dim && fDim == dim.value() && kind == fKind) ||
+            (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
+          fieldIdx = fIdx;
+          // Returns false to break the iteration.
+          return false;
+        }
+        return true;
+      });
+  assert(fieldIdx != -1u);
+  return fieldIdx;
+}
+
+unsigned
+builder::StorageLayout::getMemRefFieldIndex(StorageSpecifierKind kind,
+                                            Optional<unsigned> dim) const {
+  return getMemRefFieldIndex(toFieldKind(kind), dim);
+}
+
+Value builder::SparseTensorSpecifier::getInitValue(OpBuilder &builder,
+                                                   Location loc,
+                                                   RankedTensorType rtp) {
+  return builder.create<StorageSpecifierInitOp>(
+      loc, StorageSpecifierType::get(getSparseTensorEncoding(rtp)));
+}
+
+Value builder::SparseTensorSpecifier::getSpecifierField(
+    OpBuilder &builder, Location loc, StorageSpecifierKind kind,
+    Optional<unsigned> dim) {
+  return createIndexCast(builder, loc,
+                         builder.create<GetStorageSpecifierOp>(
+                             loc, getFieldType(kind, dim), specifier, kind,
+                             fromOptionalInt(specifier.getContext(), dim)),
+                         builder.getIndexType());
+}
+
+void builder::SparseTensorSpecifier::setSpecifierField(
+    OpBuilder &builder, Location loc, Value v, StorageSpecifierKind kind,
+    Optional<unsigned> dim) {
+  specifier = builder.create<SetStorageSpecifierOp>(
+      loc, specifier, kind, fromOptionalInt(specifier.getContext(), dim),
+      createIndexCast(builder, loc, v, getFieldType(kind, dim)));
+}
+
+// Data fields are numbered starting from this index.
+constexpr uint64_t kDataFieldStartingIdx = 0;
+
+void sparse_tensor::builder::foreachFieldInSparseTensor(
+    const SparseTensorEncodingAttr enc,
+    llvm::function_ref<bool(unsigned, SparseTensorFieldKind, unsigned,
+                            DimLevelType)>
+        callback) {
+  assert(enc);
+
+#define RETURN_ON_FALSE(idx, kind, dim, dlt)                                   \
+  do {                                                                         \
+    if (!(callback(idx, kind, dim, dlt)))                                      \
+      return;                                                                  \
+  } while (false)
+
+  static_assert(kDataFieldStartingIdx == 0);
+  unsigned fieldIdx = kDataFieldStartingIdx;
+  // Per-dimension storage.
+  for (unsigned r = 0, rank = enc.getDimLevelType().size(); r < rank; r++) {
+    // Dimension level types apply in order to the reordered dimension.
+    // As a result, the compound type can be constructed directly in the given
+    // order.
+    auto dlt = getDimLevelType(enc, r);
+    if (isCompressedDLT(dlt)) {
+      RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PtrMemRef, r, dlt);
+      RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt);
+    } else if (isSingletonDLT(dlt)) {
+      RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt);
+    } else {
+      assert(isDenseDLT(dlt)); // no fields
+    }
+  }
+  // The values array.
+  RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::ValMemRef, -1u,
+                  DimLevelType::Undef);
+
+  // Put metadata at the end.
+  RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::StorageSpec, -1u,
+                  DimLevelType::Undef);
+
+#undef RETURN_ON_FALSE
+}
+
+void sparse_tensor::builder::foreachFieldAndTypeInSparseTensor(
+    RankedTensorType rType,
+    llvm::function_ref<bool(Type, unsigned, SparseTensorFieldKind, unsigned,
+                            DimLevelType)>
+        callback) {
+  auto enc = getSparseTensorEncoding(rType);
+  assert(enc);
+  // Construct the basic types.
+  Type idxType = enc.getIndexType();
+  Type ptrType = enc.getPointerType();
+  Type eltType = rType.getElementType();
+
+  Type metaDataType = StorageSpecifierType::get(enc);
+  // memref<? x ptr> pointers
+  Type ptrMemType = MemRefType::get({ShapedType::kDynamic}, ptrType);
+  // memref<? x idx> indices
+  Type idxMemType = MemRefType::get({ShapedType::kDynamic}, idxType);
+  // memref<? x eltType> values
+  Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType);
+
+  foreachFieldInSparseTensor(
+      enc,
+      [metaDataType, ptrMemType, idxMemType, valMemType,
+       callback](unsigned fieldIdx, SparseTensorFieldKind fieldKind,
+                 unsigned dim, DimLevelType dlt) -> bool {
+        switch (fieldKind) {
+        case SparseTensorFieldKind::StorageSpec:
+          return callback(metaDataType, fieldIdx, fieldKind, dim, dlt);
+        case SparseTensorFieldKind::PtrMemRef:
+          return callback(ptrMemType, fieldIdx, fieldKind, dim, dlt);
+        case SparseTensorFieldKind::IdxMemRef:
+          return callback(idxMemType, fieldIdx, fieldKind, dim, dlt);
+        case SparseTensorFieldKind::ValMemRef:
+          return callback(valMemType, fieldIdx, fieldKind, dim, dlt);
+        }
+        llvm_unreachable("unrecognized field kind");
+      });
+}
+
+unsigned
+sparse_tensor::builder::getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) {
+  unsigned numFields = 0;
+  foreachFieldInSparseTensor(enc,
+                             [&numFields](unsigned, SparseTensorFieldKind,
+                                          unsigned, DimLevelType) -> bool {
+                               numFields++;
+                               return true;
+                             });
+  return numFields;
+}
+
+unsigned sparse_tensor::builder::getNumDataFieldsFromEncoding(
+    SparseTensorEncodingAttr enc) {
+  // Every field except the trailing storage specifier is a data memref.
+  unsigned numFields = 0;
+  foreachFieldInSparseTensor(enc,
+                             [&numFields](unsigned, SparseTensorFieldKind fKind,
+                                          unsigned, DimLevelType) -> bool {
+                               if (fKind != SparseTensorFieldKind::StorageSpec)
+                                 numFields++;
+                               return true;
+                             });
+  assert(numFields ==
+         builder::getNumFieldsFromEncoding(enc) - kDataFieldStartingIdx - 1);
+  return numFields;
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -28,6 +28,7 @@
 #define GEN_PASS_DEF_SPARSETENSORCODEGEN
 #define GEN_PASS_DEF_SPARSEBUFFERREWRITE
 #define GEN_PASS_DEF_SPARSEVECTORIZATION
+#define GEN_PASS_DEF_SPARSESPECIFIERTOLLVM
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
 } // namespace mlir
 
@@ -193,9 +194,15 @@
     target.addLegalOp<SortOp>();
     target.addLegalOp<SortCooOp>();
     target.addLegalOp<PushBackOp>();
-    // All dynamic rules below accept new function, call, return, and various
-    // tensor and bufferization operations as legal output of the rewriting
-    // provided that all sparse tensor types have been fully rewritten.
+    // Storage specifier ops outlive the sparse tensor codegen pipeline; they
+    // are lowered separately by the sparse-specifier-to-llvm pass.
+    target.addLegalOp<StorageSpecifierInitOp>();
+    target.addLegalOp<GetStorageSpecifierOp>();
+    target.addLegalOp<SetStorageSpecifierOp>();
+    // All dynamic rules below accept new function, call, return, and
+    // various tensor and bufferization operations as legal output of the
+    // rewriting provided that all sparse tensor types have been fully
+    // rewritten.
     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
       return converter.isSignatureLegal(op.getFunctionType());
     });
@@ -271,6 +278,45 @@
   }
 };
 
+/// Lowers sparse storage specifiers and their operations to LLVM structs.
+struct SparseSpecifierToLLVMPass
+    : public impl::SparseSpecifierToLLVMBase<SparseSpecifierToLLVMPass> {
+
+  SparseSpecifierToLLVMPass() = default;
+
+  void runOnOperation() override {
+    auto *ctx = &getContext();
+    ConversionTarget target(*ctx);
+    RewritePatternSet patterns(ctx);
+    SparseSpecifierToLLVMTypeConverter converter;
+
+    // All ops in the sparse dialect must go!
+    target.addIllegalDialect<SparseTensorDialect>();
+    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
+      return converter.isSignatureLegal(op.getFunctionType());
+    });
+    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
+      return converter.isSignatureLegal(op.getCalleeType());
+    });
+    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
+      return converter.isLegal(op.getOperandTypes());
+    });
+    target.addLegalDialect<arith::ArithDialect, LLVM::LLVMDialect>();
+
+    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
+                                                                   converter);
+    populateCallOpTypeConversionPattern(patterns, converter);
+    populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
+    populateReturnOpTypeConversionPattern(patterns, converter);
+    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
+                                                         target);
+    populateSparseSpecifierToLLVMPatterns(converter, patterns);
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -355,3 +401,7 @@
   return std::make_unique<SparseVectorizationPass>(
       vectorLength, enableVLAVectorization, enableSIMDIndex32);
 }
+
+std::unique_ptr<Pass> mlir::createSparseSpecifierToLLVMPass() {
+  return std::make_unique<SparseSpecifierToLLVMPass>();
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -147,6 +147,7 @@
     } else {
       pm.addPass(createSparseTensorCodegenPass(enableBufferInitialization));
       pm.addPass(createSparseBufferRewritePass(enableBufferInitialization));
+      pm.addPass(createSparseSpecifierToLLVMPass());
     }
     if (failed(runPipeline(pm, getOperation())))
       return signalPassFailure();