diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h --- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h @@ -177,6 +177,11 @@ return dlt == DimLevelType::Dense; } +/// Strip the property bits from the `DimLevelType` +constexpr DimLevelType stripLevelProperty(DimLevelType dlt) { + return static_cast(static_cast(dlt) & ~3); +} + // We use the idiom `(dlt & ~3) == format` in order to only return true // for valid DLTs. Whereas the `dlt & format` idiom is a bit faster but // can return false-positives on invalid DLTs. diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -342,17 +342,18 @@ "inBuffer", "value", "$_self.cast().getElementType()">, AllTypesMatch<["inBuffer", "outBuffer"]>]>, - Arguments<(ins StridedMemRefRankOf<[Index], [1]>:$bufferSizes, + Arguments<(ins Index:$curSize, StridedMemRefRankOf<[AnyType], [1]>:$inBuffer, - AnyType:$value, IndexAttr:$idx, Optional:$n, + AnyType:$value, Optional:$n, UnitAttr:$inbounds)>, - Results<(outs StridedMemRefRankOf<[AnyType], [1]>:$outBuffer)> { + Results<(outs StridedMemRefRankOf<[AnyType], [1]>:$outBuffer, + Index:$newSize)> { string summary = "Pushes a value to the back of a given buffer"; string description = [{ Push `value` to the end of the given sparse tensor storage buffer - `inBuffer` and update the size of the buffer in `bufferSizes[idx]`. The - capacity of the buffer is recorded in the memref type of `inBuffer `. If the - current buffer is full, then `inBuffer.realloc` is called before pushing the + `inBuffer` according to `curSize` and return the new size of the buffer in + `newSize`. The capacity of the buffer is recorded in the memref type of `inBuffer`. 
+ If the current buffer is full, then `inBuffer.realloc` is called before pushing the data to the buffer. This is similar to std::vector push_back. The optional input `n` specifies the number of times to repeately push @@ -375,29 +376,28 @@ Example: ```mlir - %r = sparse_tensor.push_back %bufferSizes, %buffer, %val - {idx = 0 : index} : memref, memref, f64 + %buf, %newSize = sparse_tensor.push_back %curSize, %buffer, %val + : index, memref, f64 ``` ```mlir - %r = sparse_tensor.push_back inbounds %bufferSizes, %buffer, %val - {idx = 0 : index} : memref, memref, f64 + %buf, %newSize = sparse_tensor.push_back inbounds %curSize, %buffer, %val + : index, memref, f64 ``` ```mlir - %r = sparse_tensor.push_back inbounds %bufferSizes, %buffer, %val, %n - {idx = 0 : index} : memref, memref, f64 + %buf, %newSize = sparse_tensor.push_back inbounds %curSize, %buffer, %val, %n + : index, memref, f64, index ``` }]; - let assemblyFormat = "(`inbounds` $inbounds^)? $bufferSizes `,` $inBuffer" + let assemblyFormat = "(`inbounds` $inbounds^)? $curSize `,` $inBuffer" " `,` $value (`,` $n^ )? attr-dict `:`" - " type($bufferSizes) `,` type($inBuffer) `,`" - " type($value) (`,` type($n)^ )?"; + " type($curSize) `,` type($inBuffer) `,`" + " type($value) (`,` type($n)^ )?"; let builders = [ - //Build an op without input `n`. 
- OpBuilder<(ins "Type":$outBuffer, "Value":$bufferSizes, "Value":$inBuffer, - "Value":$value, "APInt":$idx)> + //Build an op (reusing type from curSize and inBuffer) without input `n` + OpBuilder<(ins "Value":$curSize, "Value":$inBuffer, "Value":$value)> ]; let hasVerifier = 1; diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td @@ -54,7 +54,7 @@ let builders = [ TypeBuilderWithInferredContext<(ins "SparseTensorEncodingAttr":$encoding), [{ assert(encoding && "sparse tensor encoding should not be null"); - return $_get(encoding.getContext(), encoding); + return get(encoding.getContext(), encoding); }]>, TypeBuilderWithInferredContext<(ins "Type":$type), [{ return get(getSparseTensorEncoding(type)); @@ -69,8 +69,10 @@ IntegerType getSizesType() const; Type getFieldType(StorageSpecifierKind kind, Optional dim) const; Type getFieldType(StorageSpecifierKind kind, Optional dim) const; + static StorageSpecifierType get(MLIRContext *ctx, SparseTensorEncodingAttr enc); }]; - + + let skipDefaultBuilders = 1; let assemblyFormat="`<` qualified($encoding) `>`"; } diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h @@ -122,10 +122,12 @@ // The SparseTensorCodegen pass. //===----------------------------------------------------------------------===// -/// Sparse tensor type converter into an actual buffer. class SparseTensorTypeToBufferConverter : public TypeConverter { public: SparseTensorTypeToBufferConverter(); + + Optional + convertSparseTensorType(RankedTensorType rtp, SmallVectorImpl &fields); }; /// Sets up sparse tensor conversion rules. 
@@ -186,6 +188,15 @@ bool enableVLAVectorization, bool enableSIMDIndex32); +class SparseSpecifierToLLVMTypeConverter : public TypeConverter { +public: + SparseSpecifierToLLVMTypeConverter(); +}; + +void populateSparseSpecifierToLLVMPatterns(TypeConverter &converter, + RewritePatternSet &patterns); +std::unique_ptr createSparseSpecifierToLLVMPass(); + //===----------------------------------------------------------------------===// // Registration. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td @@ -301,4 +301,16 @@ ]; } +def SparseSpecifierToLLVM : Pass<"sparse-specifier-to-llvm", "ModuleOp"> { + let summary = "Rewrite sparse primitives on buffers to actual code"; + let description = [{ + }]; + let constructor = "mlir::createSparseSpecifierToLLVMPass()"; + let dependentDialects = [ + "arith::ArithDialect", + "LLVM::LLVMDialect", + "sparse_tensor::SparseTensorDialect", + ]; +} + #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -341,6 +341,27 @@ // SparseTensorDialect Types. //===----------------------------------------------------------------------===// +static SparseTensorEncodingAttr +getCanonicalizedEncoding(SparseTensorEncodingAttr enc) { + SmallVector dlts; + for (auto dlt : enc.getDimLevelType()) + dlts.push_back(stripLevelProperty(dlt)); + + AffineMap dimOrder = + enc.getDimOrdering() && !enc.getDimOrdering().isIdentity() + ? 
enc.getDimOrdering() + : AffineMap(); + + return SparseTensorEncodingAttr::get( + enc.getContext(), dlts, dimOrder, enc.getHigherOrdering(), + enc.getPointerBitWidth(), enc.getIndexBitWidth()); +} + +StorageSpecifierType StorageSpecifierType::get(MLIRContext *ctx, + SparseTensorEncodingAttr enc) { + return Base::get(ctx, getCanonicalizedEncoding(enc)); +} + IntegerType StorageSpecifierType::getSizesType() const { unsigned idxBitWidth = getEncoding().getIndexBitWidth() ? getEncoding().getIndexBitWidth() : 64u; @@ -677,10 +698,8 @@ } void PushBackOp::build(OpBuilder &builder, OperationState &result, - Type outBuffer, Value bufferSizes, Value inBuffer, - Value value, APInt idx) { - build(builder, result, outBuffer, bufferSizes, inBuffer, value, - std::move(idx), Value()); + Value curSize, Value inBuffer, Value value) { + build(builder, result, curSize, inBuffer, value, Value()); } LogicalResult PushBackOp::verify() { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt @@ -2,6 +2,8 @@ BufferizableOpInterfaceImpl.cpp CodegenUtils.cpp SparseBufferRewriting.cpp + SparseSpecifierToLLVM.cpp + SparseTensorBuilder.cpp SparseTensorCodegen.cpp SparseTensorConversion.cpp SparseTensorPasses.cpp diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h @@ -308,220 +308,6 @@ return !rtp || rtp.getRank() == 0; } -//===----------------------------------------------------------------------===// -// SparseTensorDescriptor and helpers, manage the sparse tensor memory layout -// scheme. 
-// -// Sparse tensor storage scheme for rank-dimensional tensor is organized -// as a single compound type with the following fields. Note that every -// memref with ? size actually behaves as a "vector", i.e. the stored -// size is the capacity and the used size resides in the memSizes array. -// -// struct { -// memref dimSizes ; size in each dimension -// memref memSizes ; sizes of ptrs/inds/values -// ; per-dimension d: -// ; if dense: -// -// ; if compresed: -// memref pointers-d ; pointers for sparse dim d -// memref indices-d ; indices for sparse dim d -// ; if singleton: -// memref indices-d ; indices for singleton dim d -// memref values ; values -// }; -// -//===----------------------------------------------------------------------===// -enum class SparseTensorFieldKind { - DimSizes, - MemSizes, - PtrMemRef, - IdxMemRef, - ValMemRef -}; - -constexpr uint64_t dimSizesIdx = 0; -constexpr uint64_t memSizesIdx = dimSizesIdx + 1; -constexpr uint64_t dataFieldIdx = memSizesIdx + 1; - -/// For each field that will be allocated for the given sparse tensor encoding, -/// calls the callback with the corresponding field index, field kind, dimension -/// (for sparse tensor level memrefs) and dimlevelType. -/// The field index always starts with zero and increments by one between two -/// callback invocations. -/// Ideally, all other methods should rely on this function to query a sparse -/// tensor fields instead of relying on ad-hoc index computation. -void foreachFieldInSparseTensor( - SparseTensorEncodingAttr, - llvm::function_ref); - -/// Same as above, except that it also builds the Type for the corresponding -/// field. -void foreachFieldAndTypeInSparseTensor( - RankedTensorType, - llvm::function_ref); - -/// Gets the total number of fields for the given sparse tensor encoding. 
-unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc); - -/// Gets the total number of data fields (index arrays, pointer arrays, and a -/// value array) for the given sparse tensor encoding. -unsigned getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc); - -/// Get the index of the field in memSizes (only valid for data fields). -inline unsigned getFieldMemSizesIndex(unsigned fid) { - assert(fid >= dataFieldIdx); - return fid - dataFieldIdx; -} - -template -struct SparseTensorValueArrayRef; - -// Uses ValueRange for immuatable descriptors; uses SmallVectorImpl & -// for mutable descriptors. -template <> -struct SparseTensorValueArrayRef { - using ValueArray = ValueRange; -}; - -// Using SmallVector for mutable descriptor allows users to reuse it as a tmp -// buffers to append value for some special cases, though users should be -// responsible to restore the buffer to legal states after their use. It is -// probably not a clean way, but it is the most efficient way to avoid copying -// the fields into another SmallVector. If a more clear way is wanted, we -// should change it to MutableArrayRef instead. -template <> -struct SparseTensorValueArrayRef { - using ValueArray = SmallVectorImpl &; -}; - -/// A helper class around an array of values that corresponding to a sparse -/// tensor, provides a set of meaningful APIs to query and update a particular -/// field in a consistent way. -/// Users should not make assumption on how a sparse tensor is laid out but -/// instead relies on this class to access the right value for the right field. 
-template -class SparseTensorDescriptorImpl { -private: - using Storage = typename SparseTensorValueArrayRef::ValueArray; - -public: - SparseTensorDescriptorImpl(Type tp, Storage fields) - : rType(tp.cast()), fields(fields) { - assert(getSparseTensorEncoding(tp) && - getNumFieldsFromEncoding(getSparseTensorEncoding(tp)) == - fields.size()); - // We should make sure the class is trivially copyable (and should be small - // enough) such that we can pass it by value. - static_assert( - std::is_trivially_copyable_v>); - } - - // Implicit (and cheap) type conversion from MutSparseTensorDescriptor to - // SparseTensorDescriptor. - template > - /*implicit*/ SparseTensorDescriptorImpl(std::enable_if_t &mDesc) - : rType(mDesc.getTensorType()), fields(mDesc.getFields()) {} - - /// - /// Getters: get the field index for required field. - /// - - unsigned getPtrMemRefIndex(unsigned ptrDim) const { - return getFieldIndex(ptrDim, SparseTensorFieldKind::PtrMemRef); - } - - unsigned getIdxMemRefIndex(unsigned idxDim) const { - return getFieldIndex(idxDim, SparseTensorFieldKind::IdxMemRef); - } - - unsigned getValMemRefIndex() const { return fields.size() - 1; } - - unsigned getPtrMemSizesIndex(unsigned dim) const { - return getPtrMemRefIndex(dim) - dataFieldIdx; - } - - unsigned getIdxMemSizesIndex(unsigned dim) const { - return getIdxMemRefIndex(dim) - dataFieldIdx; - } - - unsigned getValMemSizesIndex() const { - return getValMemRefIndex() - dataFieldIdx; - } - - unsigned getNumFields() const { return fields.size(); } - - /// - /// Getters: get the value for required field. 
- /// - - Value getDimSizesMemRef() const { return fields[dimSizesIdx]; } - Value getMemSizesMemRef() const { return fields[memSizesIdx]; } - - Value getPtrMemRef(unsigned ptrDim) const { - return fields[getPtrMemRefIndex(ptrDim)]; - } - - Value getIdxMemRef(unsigned idxDim) const { - return fields[getIdxMemRefIndex(idxDim)]; - } - - Value getValMemRef() const { return fields[getValMemRefIndex()]; } - - Value getField(unsigned fid) const { - assert(fid < fields.size()); - return fields[fid]; - } - - /// - /// Setters: update the value for required field (only enabled for - /// MutSparseTensorDescriptor). - /// - - template - void setField(unsigned fid, std::enable_if_t v) { - assert(fid < fields.size()); - fields[fid] = v; - } - - RankedTensorType getTensorType() const { return rType; } - Storage getFields() const { return fields; } - - Type getElementType(unsigned fidx) const { - return fields[fidx].getType().template cast().getElementType(); - } - -private: - unsigned getFieldIndex(unsigned dim, SparseTensorFieldKind kind) const { - unsigned fieldIdx = -1u; - foreachFieldInSparseTensor( - getSparseTensorEncoding(rType), - [dim, kind, &fieldIdx](unsigned fIdx, SparseTensorFieldKind fKind, - unsigned fDim, DimLevelType dlt) -> bool { - if (fDim == dim && kind == fKind) { - fieldIdx = fIdx; - // Returns false to break the iteration. - return false; - } - return true; - }); - assert(fieldIdx != -1u); - return fieldIdx; - } - - RankedTensorType rType; - Storage fields; -}; - -using SparseTensorDescriptor = SparseTensorDescriptorImpl; -using MutSparseTensorDescriptor = SparseTensorDescriptorImpl; - //===----------------------------------------------------------------------===// // SparseTensorLoopEmiter class, manages sparse tensors and helps to // generate loop structure to (co)-iterate sparse tensors. 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp @@ -90,116 +90,6 @@ return val; } -void sparse_tensor::foreachFieldInSparseTensor( - const SparseTensorEncodingAttr enc, - llvm::function_ref - callback) { - assert(enc); - -#define RETURN_ON_FALSE(idx, kind, dim, dlt) \ - if (!(callback(idx, kind, dim, dlt))) \ - return; - - RETURN_ON_FALSE(dimSizesIdx, SparseTensorFieldKind::DimSizes, -1u, - DimLevelType::Undef); - RETURN_ON_FALSE(memSizesIdx, SparseTensorFieldKind::MemSizes, -1u, - DimLevelType::Undef); - - static_assert(dataFieldIdx == memSizesIdx + 1); - unsigned fieldIdx = dataFieldIdx; - // Per-dimension storage. - for (unsigned r = 0, rank = enc.getDimLevelType().size(); r < rank; r++) { - // Dimension level types apply in order to the reordered dimension. - // As a result, the compound type can be constructed directly in the given - // order. - auto dlt = getDimLevelType(enc, r); - if (isCompressedDLT(dlt)) { - RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PtrMemRef, r, dlt); - RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt); - } else if (isSingletonDLT(dlt)) { - RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt); - } else { - assert(isDenseDLT(dlt)); // no fields - } - } - // The values array. - RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::ValMemRef, -1u, - DimLevelType::Undef); - -#undef RETURN_ON_FALSE -} - -void sparse_tensor::foreachFieldAndTypeInSparseTensor( - RankedTensorType rType, - llvm::function_ref - callback) { - auto enc = getSparseTensorEncoding(rType); - assert(enc); - // Construct the basic types. 
- Type indexType = IndexType::get(enc.getContext()); - Type idxType = enc.getIndexType(); - Type ptrType = enc.getPointerType(); - Type eltType = rType.getElementType(); - unsigned rank = rType.getShape().size(); - // memref dimSizes - Type dimSizeType = MemRefType::get({rank}, indexType); - // memref memSizes - Type memSizeType = - MemRefType::get({getNumDataFieldsFromEncoding(enc)}, indexType); - // memref pointers - Type ptrMemType = MemRefType::get({ShapedType::kDynamic}, ptrType); - // memref indices - Type idxMemType = MemRefType::get({ShapedType::kDynamic}, idxType); - // memref values - Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType); - - foreachFieldInSparseTensor( - enc, - [dimSizeType, memSizeType, ptrMemType, idxMemType, valMemType, - callback](unsigned fieldIdx, SparseTensorFieldKind fieldKind, - unsigned dim, DimLevelType dlt) -> bool { - switch (fieldKind) { - case SparseTensorFieldKind::DimSizes: - return callback(dimSizeType, fieldIdx, fieldKind, dim, dlt); - case SparseTensorFieldKind::MemSizes: - return callback(memSizeType, fieldIdx, fieldKind, dim, dlt); - case SparseTensorFieldKind::PtrMemRef: - return callback(ptrMemType, fieldIdx, fieldKind, dim, dlt); - case SparseTensorFieldKind::IdxMemRef: - return callback(idxMemType, fieldIdx, fieldKind, dim, dlt); - case SparseTensorFieldKind::ValMemRef: - return callback(valMemType, fieldIdx, fieldKind, dim, dlt); - }; - llvm_unreachable("unrecognized field kind"); - }); -} - -unsigned sparse_tensor::getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) { - unsigned numFields = 0; - foreachFieldInSparseTensor(enc, - [&numFields](unsigned, SparseTensorFieldKind, - unsigned, DimLevelType) -> bool { - numFields++; - return true; - }); - return numFields; -} - -unsigned -sparse_tensor::getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc) { - unsigned numFields = 0; // one value memref - foreachFieldInSparseTensor(enc, - [&numFields](unsigned fidx, SparseTensorFieldKind, - 
unsigned, DimLevelType) -> bool { - if (fidx >= dataFieldIdx) - numFields++; - return true; - }); - assert(numFields == getNumFieldsFromEncoding(enc) - dataFieldIdx); - return numFields; -} //===----------------------------------------------------------------------===// // Sparse tensor loop emitter class implementations //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp @@ -331,7 +331,7 @@ Location loc = func.getLoc(); ValueRange args = entryBlock->getArguments(); Value p = args[hiIdx]; - SmallVector types(2, p.getType()); // only two + SmallVector types(2, p.getType()); // only two scf::WhileOp whileOp = builder.create( loc, types, SmallVector{args[loIdx], args[hiIdx]}); @@ -490,7 +490,7 @@ Value i = lo; Value j = builder.create(loc, hi, c1); - SmallVector operands{i, j, p}; // exactly three + SmallVector operands{i, j, p}; // exactly three SmallVector types{i.getType(), j.getType(), p.getType()}; scf::WhileOp whileOp = builder.create(loc, types, operands); @@ -770,9 +770,7 @@ Value c0 = constantIndex(rewriter, loc, 0); Value buffer = op.getInBuffer(); Value capacity = rewriter.create(loc, buffer, c0); - Value idx = constantIndex(rewriter, loc, op.getIdx().getZExtValue()); - Value bufferSizes = op.getBufferSizes(); - Value size = rewriter.create(loc, bufferSizes, idx); + Value size = op.getCurSize(); Value value = op.getValue(); Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1); @@ -852,8 +850,7 @@ } // Update the buffer size. 
- rewriter.create(loc, newSize, bufferSizes, idx); - rewriter.replaceOp(op, buffer); + rewriter.replaceOp(op, {buffer, newSize}); return success(); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpecifierToLLVM.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpecifierToLLVM.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpecifierToLLVM.cpp @@ -0,0 +1,184 @@ +//===- SparseSpecifierToLLVM.cpp - Lower storage specifiers to LLVM -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "CodegenUtils.h" +#include "SparseTensorBuilder.h" + +#include "mlir/Dialect/SparseTensor/Transforms/Passes.h" + +using namespace mlir; +using namespace sparse_tensor; + +static SmallVector getSpecifierFields(StorageSpecifierType tp) { + MLIRContext *ctx = tp.getContext(); + auto enc = tp.getEncoding(); + unsigned rank = enc.getDimLevelType().size(); + + SmallVector result; + auto indexType = tp.getSizesType(); + auto dimSizes = LLVM::LLVMArrayType::get(ctx, indexType, rank); + auto memSizes = LLVM::LLVMArrayType::get(ctx, indexType, + getNumDataFieldsFromEncoding(enc)); + result.push_back(dimSizes); + result.push_back(memSizes); + return result; +} + +static Type convertSpecifier(StorageSpecifierType tp) { + return LLVM::LLVMStructType::getLiteral(tp.getContext(), + getSpecifierFields(tp)); +} + +SparseSpecifierToLLVMTypeConverter::SparseSpecifierToLLVMTypeConverter() { + addConversion([](Type type) { return type; }); + addConversion([](StorageSpecifierType tp) { return convertSpecifier(tp); }); +} + +constexpr uint64_t kDimSizePosInSpecifier = 0; +constexpr uint64_t kMemSizePosInSpecifier = 1; + +class SpecifierStructBuilder : public StructBuilder { +public: + 
explicit SpecifierStructBuilder(Value specifier) : StructBuilder(specifier) { + assert(value); + } + + // Undef value for dimension sizes, all zero value for memory sizes. + static Value getInitValue(OpBuilder &builder, Location loc, Type structType); + + Value dimSize(OpBuilder &builder, Location loc, unsigned dim); + void setDimSize(OpBuilder &builder, Location loc, unsigned dim, Value size); + + Value memSize(OpBuilder &builder, Location loc, unsigned pos); + void setMemSize(OpBuilder &builder, Location loc, unsigned pos, Value size); +}; + +Value SpecifierStructBuilder::getInitValue(OpBuilder &builder, Location loc, + Type structType) { + Value metaData = builder.create(loc, structType); + SpecifierStructBuilder md(metaData); + auto memSizeArrayType = structType.cast() + .getBody()[kMemSizePosInSpecifier] + .cast(); + + Value zero = constantZero(builder, loc, memSizeArrayType.getElementType()); + // Fill memSizes array with zero. + + for (int i = 0, e = memSizeArrayType.getNumElements(); i < e; i++) + md.setMemSize(builder, loc, i, zero); + + return md; +} + +/// Builds IR extracting the dim-th dimension size from the descriptor. +Value SpecifierStructBuilder::dimSize(OpBuilder &builder, Location loc, + unsigned dim) { + return builder.create( + loc, value, ArrayRef({kDimSizePosInSpecifier, dim})); +} + +/// Builds IR inserting the dim-th dimension size into the descriptor. +void SpecifierStructBuilder::setDimSize(OpBuilder &builder, Location loc, + unsigned dim, Value size) { + value = builder.create( + loc, value, size, ArrayRef({kDimSizePosInSpecifier, dim})); +} + +/// Builds IR extracting the pos-th memory size from the descriptor. +Value SpecifierStructBuilder::memSize(OpBuilder &builder, Location loc, + unsigned pos) { + return builder.create( + loc, value, ArrayRef({kMemSizePosInSpecifier, pos})); +} + +/// Builds IR inserting the pos-th memory size into the descriptor. 
+void SpecifierStructBuilder::setMemSize(OpBuilder &builder, Location loc, + unsigned pos, Value size) { + value = builder.create( + loc, value, size, ArrayRef({kMemSizePosInSpecifier, pos})); +} + +template +class SpecifierGetterSetterOpConverter : public OpConversionPattern { +public: + using OpAdaptor = typename SourceOp::Adaptor; + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(SourceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SpecifierStructBuilder spec(adaptor.getSpecifier()); + Value v; + if (op.getSpecifierKind() == StorageSpecifierKind::DimSize) { + v = Base::onDimSize(rewriter, op, spec, + op.getDim().value().getZExtValue()); + } else { + auto enc = op.getSpecifier().getType().getEncoding(); + StorageLayout layout(enc); + Optional dim = std::nullopt; + if (op.getDim()) + dim = op.getDim().value().getZExtValue(); + unsigned idx = layout.getMemRefFieldIndex(op.getSpecifierKind(), dim); + v = Base::onMemSize(rewriter, op, spec, idx); + } + + rewriter.replaceOp(op, v); + return success(); + } +}; + +struct SpecifierSetOpConverter + : public SpecifierGetterSetterOpConverter { + using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter; + static Value onDimSize(OpBuilder &builder, SetStorageSpecifierOp op, + SpecifierStructBuilder &spec, unsigned d) { + spec.setDimSize(builder, op.getLoc(), d, op.getValue()); + return spec; + } + + static Value onMemSize(OpBuilder &builder, SetStorageSpecifierOp op, + SpecifierStructBuilder &spec, unsigned i) { + spec.setMemSize(builder, op.getLoc(), i, op.getValue()); + return spec; + } +}; + +struct SpecifierGetOpConverter + : public SpecifierGetterSetterOpConverter { + using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter; + static Value onDimSize(OpBuilder &builder, GetStorageSpecifierOp op, + SpecifierStructBuilder &spec, unsigned d) { + return spec.dimSize(builder, op.getLoc(), d); + } + static Value 
onMemSize(OpBuilder &builder, GetStorageSpecifierOp op, + SpecifierStructBuilder &spec, unsigned i) { + return spec.memSize(builder, op.getLoc(), i); + } +}; + +struct SpecifierInitOpConverter + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(StorageSpecifierInitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type llvmType = getTypeConverter()->convertType(op.getResult().getType()); + rewriter.replaceOp(op, SpecifierStructBuilder::getInitValue( + rewriter, op.getLoc(), llvmType)); + return success(); + } +}; + +void mlir::populateSparseSpecifierToLLVMPatterns(TypeConverter &converter, + RewritePatternSet &patterns) { + patterns.add(converter, patterns.getContext()); +} diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorBuilder.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorBuilder.h new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorBuilder.h @@ -0,0 +1,358 @@ +//===- SparseTensorBuilder.h ------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines utilities for lowering and access sparse tensor +// types. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORBUILDER_H_ +#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORBUILDER_H_ + +#include "mlir/Conversion/LLVMCommon/StructBuilder.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/Dialect/SparseTensor/Transforms/Passes.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace sparse_tensor { + +//===----------------------------------------------------------------------===// +// SparseTensorDescriptor and helpers, manage the sparse tensor memory layout +// scheme. +// +// Sparse tensor storage scheme for rank-dimensional tensor is organized +// as a single compound type with the following fields. Note that every +// memref with ? size actually behaves as a "vector", i.e. the stored +// size is the capacity and the used size resides in the memSizes array. +// +// struct { +// ; per-dimension d: +// ; if dense: +// +// ; if compressed: +// memref pointers-d ; pointers for sparse dim d +// memref indices-d ; indices for sparse dim d +// ; if singleton: +// memref indices-d ; indices for singleton dim d +// memref values ; values +// +// ; sparse tensor metadata +// struct { +// array dimSizes ; sizes for each dimension +// array memSizes ; sizes for each data memref +// } +// }; +// +//===----------------------------------------------------------------------===// +enum class SparseTensorFieldKind : uint32_t { + StorageSpec = 0, + PtrMemRef = 1, + IdxMemRef = 2, + ValMemRef = 3 +}; + +static_assert(static_cast(SparseTensorFieldKind::PtrMemRef) == + static_cast(StorageSpecifierKind::PtrMemSize)); +static_assert(static_cast(SparseTensorFieldKind::IdxMemRef) == + static_cast(StorageSpecifierKind::IdxMemSize)); +static_assert(static_cast(SparseTensorFieldKind::ValMemRef) == + static_cast(StorageSpecifierKind::ValMemSize)); + +/// For each field that will be allocated for the 
given sparse tensor encoding, +/// calls the callback with the corresponding field index, field kind, dimension +/// (for sparse tensor level memrefs) and dimlevelType. +/// The field index always starts with zero and increments by one between two +/// callback invocations. +/// Ideally, all other methods should rely on this function to query a sparse +/// tensor fields instead of relying on ad-hoc index computation. +void foreachFieldInSparseTensor( + SparseTensorEncodingAttr, + llvm::function_ref); + +/// Same as above, except that it also builds the Type for the corresponding +/// field. +void foreachFieldAndTypeInSparseTensor( + RankedTensorType, + llvm::function_ref); + +/// Gets the total number of fields for the given sparse tensor encoding. +unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc); + +/// Gets the total number of data fields (index arrays, pointer arrays, and a +/// value array) for the given sparse tensor encoding. +unsigned getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc); + +inline StorageSpecifierKind toSpecifierKind(SparseTensorFieldKind kind) { + assert(kind != SparseTensorFieldKind::StorageSpec); + return static_cast(kind); +} + +inline SparseTensorFieldKind toFieldKind(StorageSpecifierKind kind) { + assert(kind != StorageSpecifierKind::DimSize); + return static_cast(kind); +} + +class StorageLayout { +public: + explicit StorageLayout(SparseTensorEncodingAttr enc) : enc(enc) {} + + /// + /// Getters: get the field index for required field. + /// + unsigned getMemRefFieldIndex(SparseTensorFieldKind kind, + Optional dim) const; + + unsigned getMemRefFieldIndex(StorageSpecifierKind kind, + Optional dim) const; + +private: + unsigned getFieldIndex(unsigned dim, SparseTensorFieldKind kind) const; + SparseTensorEncodingAttr enc; +}; + +class SparseTensorSpecifier { +public: + explicit SparseTensorSpecifier(Value specifier) : specifier(specifier) {} + + // Undef value for dimension sizes, all zero value for memory sizes. 
+ static Value getInitValue(OpBuilder &builder, Location loc, + RankedTensorType rtp); + + /*implicit*/ operator Value() { return specifier; } + + Value getSpecifierField(OpBuilder &builder, Location loc, + StorageSpecifierKind kind, Optional dim); + + void setSpecifierField(OpBuilder &builder, Location loc, Value v, + StorageSpecifierKind kind, Optional dim); + + Type getFieldType(StorageSpecifierKind kind, Optional dim) { + return specifier.getType().getFieldType(kind, dim); + } + +private: + TypedValue specifier; +}; + +/// A helper class around an array of values that correspond to a sparse +/// tensor, providing a set of meaningful APIs to query and update a particular +/// field in a consistent way. +/// Users should not make assumptions about how a sparse tensor is laid out but +/// instead rely on this class to access the right value for the right field. +template +class SparseTensorDescriptorImpl { +private: + // Uses ValueRange for immutable descriptors; uses SmallVectorImpl & + // for mutable descriptors. + // Using SmallVector for mutable descriptor allows users to reuse it as a tmp + // buffer to append values for some special cases, though users should be + // responsible to restore the buffer to legal states after their use. It is + // probably not a clean way, but it is the most efficient way to avoid copying + // the fields into another SmallVector. If a cleaner way is wanted, we + // should change it to MutableArrayRef instead. + using ValueArrayRef = typename std::conditional &, + ValueRange>::type; + +public: + SparseTensorDescriptorImpl(Type tp, ValueArrayRef fields) + : rType(tp.cast()), fields(fields) { + assert(getSparseTensorEncoding(tp) && + getNumFieldsFromEncoding(getSparseTensorEncoding(tp)) == + fields.size()); + // We should make sure the class is trivially copyable (and should be small + // enough) such that we can pass it by value.
+ static_assert( + std::is_trivially_copyable_v>); + } + + // Implicit (and cheap) type conversion from MutSparseTensorDescriptor to + // SparseTensorDescriptor. + template > + /*implicit*/ SparseTensorDescriptorImpl(std::enable_if_t &mDesc) + : rType(mDesc.getTensorType()), fields(mDesc.getFields()) {} + + unsigned getMemRefFieldIndex(SparseTensorFieldKind kind, + Optional dim) const { + // Delegates to storage layout. + StorageLayout layout(getSparseTensorEncoding(rType)); + return layout.getMemRefFieldIndex(kind, dim); + } + + unsigned getPtrMemRefIndex(unsigned ptrDim) const { + return getMemRefFieldIndex(SparseTensorFieldKind::PtrMemRef, ptrDim); + } + + unsigned getIdxMemRefIndex(unsigned idxDim) const { + return getMemRefFieldIndex(SparseTensorFieldKind::IdxMemRef, idxDim); + } + + unsigned getValMemRefIndex() const { + return getMemRefFieldIndex(SparseTensorFieldKind::ValMemRef, std::nullopt); + } + + unsigned getNumFields() const { return fields.size(); } + + /// + /// Getters: get the value for required field. 
+ /// + + Value getSpecifierField(OpBuilder &builder, Location loc, + StorageSpecifierKind kind, + Optional dim) const { + SparseTensorSpecifier md(fields.back()); + return md.getSpecifierField(builder, loc, kind, dim); + } + + Value getDimSize(OpBuilder &builder, Location loc, unsigned dim) const { + return getSpecifierField(builder, loc, StorageSpecifierKind::DimSize, dim); + } + + Value getPtrMemSize(OpBuilder &builder, Location loc, unsigned dim) const { + return getSpecifierField(builder, loc, StorageSpecifierKind::PtrMemSize, + dim); + } + + Value getIdxMemSize(OpBuilder &builder, Location loc, unsigned dim) const { + return getSpecifierField(builder, loc, StorageSpecifierKind::IdxMemSize, + dim); + } + + Value getValMemSize(OpBuilder &builder, Location loc) const { + return getSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize, + std::nullopt); + } + + Value getPtrMemRef(unsigned ptrDim) const { + return getMemRefField(SparseTensorFieldKind::PtrMemRef, ptrDim); + } + + Value getIdxMemRef(unsigned idxDim) const { + return getMemRefField(SparseTensorFieldKind::IdxMemRef, idxDim); + } + + Value getValMemRef() const { + return getMemRefField(SparseTensorFieldKind::ValMemRef, std::nullopt); + } + + Value getMemRefField(SparseTensorFieldKind kind, + Optional dim) const { + return fields[getMemRefFieldIndex(kind, dim)]; + } + + Value getMemRefField(unsigned fidx) const { + assert(fidx < fields.size() - 1); + return fields[fidx]; + } + + Value getField(unsigned fidx) const { + assert(fidx < fields.size()); + return fields[fidx]; + } + + /// + /// Setters: update the value for required field (only enabled for + /// MutSparseTensorDescriptor). 
+ /// + + template + void setMemRefField(SparseTensorFieldKind kind, Optional dim, + std::enable_if_t v) { + fields[getMemRefFieldIndex(kind, dim)] = v; + } + + template + void setMemRefField(unsigned fidx, std::enable_if_t v) { + assert(fidx < fields.size() - 1); + fields[fidx] = v; + } + + template + void setField(unsigned fidx, std::enable_if_t v) { + assert(fidx < fields.size()); + fields[fidx] = v; + } + + template + void setSpecifierField(OpBuilder &builder, Location loc, + StorageSpecifierKind kind, Optional dim, + std::enable_if_t v) { + SparseTensorSpecifier md(fields.back()); + md.setSpecifierField(builder, loc, v, kind, dim); + fields.back() = md; + } + + template + void setDimSize(OpBuilder &builder, Location loc, unsigned dim, + std::enable_if_t v) { + setSpecifierField(builder, loc, StorageSpecifierKind::DimSize, dim, v); + } + + ValueRange getMemRefFields() const { + ValueRange ret = fields; + // drop the last metadata fields + return ret.slice(0, fields.size() - 1); + } + + Type getMemRefElementType(SparseTensorFieldKind kind, + Optional dim) const { + return getMemRefField(kind, dim) + .getType() + .template cast() + .getElementType(); + } + + RankedTensorType getTensorType() const { return rType; } + ValueArrayRef getFields() const { return fields; } + +private: + RankedTensorType rType; + ValueArrayRef fields; +}; + +using SparseTensorDescriptor = SparseTensorDescriptorImpl; +using MutSparseTensorDescriptor = SparseTensorDescriptorImpl; + +/// Returns the "tuple" value of the adapted tensor. +inline UnrealizedConversionCastOp getTuple(Value tensor) { + return llvm::cast(tensor.getDefiningOp()); +} + +/// Packs the given values as a "tuple" value. 
+inline Value genTuple(OpBuilder &builder, Location loc, Type tp, + ValueRange values) { + return builder.create(loc, TypeRange(tp), values) + .getResult(0); +} + +inline Value genTuple(OpBuilder &builder, Location loc, + SparseTensorDescriptor desc) { + return genTuple(builder, loc, desc.getTensorType(), desc.getFields()); +} + +inline SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) { + auto tuple = getTuple(tensor); + return SparseTensorDescriptor(tuple.getResultTypes()[0], tuple.getInputs()); +} + +inline MutSparseTensorDescriptor +getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl &fields) { + auto tuple = getTuple(tensor); + fields.assign(tuple.getInputs().begin(), tuple.getInputs().end()); + return MutSparseTensorDescriptor(tuple.getResultTypes()[0], fields); +} + +} // namespace sparse_tensor +} // namespace mlir +#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORBUILDER_H_ diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorBuilder.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorBuilder.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorBuilder.cpp @@ -0,0 +1,221 @@ +//===- SparseTensorBuilder.cpp --------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include "SparseTensorBuilder.h" +#include "CodegenUtils.h" + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; +using namespace sparse_tensor; + +static Value createIndexCast(OpBuilder &builder, Location loc, Value value, + Type to) { + if (value.getType() != to) + return builder.create(loc, to, value); + return value; +} + +static IntegerAttr fromOptionalInt(MLIRContext *ctx, Optional dim) { + if (!dim) + return nullptr; + return IntegerAttr::get(IndexType::get(ctx), dim.value()); +} + +Optional +SparseTensorTypeToBufferConverter::convertSparseTensorType( + RankedTensorType rtp, SmallVectorImpl &fields) { + auto enc = getSparseTensorEncoding(rtp); + if (!enc) + return std::nullopt; + + foreachFieldAndTypeInSparseTensor( + rtp, + [&fields](Type fieldType, unsigned fieldIdx, + SparseTensorFieldKind /*fieldKind*/, unsigned /*dim*/, + DimLevelType /*dlt*/) -> bool { + assert(fieldIdx == fields.size()); + fields.push_back(fieldType); + return true; + }); + return success(); +} + +SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() { + addConversion([](Type type) { return type; }); + addConversion([&](RankedTensorType rtp, SmallVectorImpl &fields) { + return convertSparseTensorType(rtp, fields); + }); + + // Required by scf.for 1:N type conversion. + addSourceMaterialization([](OpBuilder &builder, RankedTensorType tp, + ValueRange inputs, + Location loc) -> Optional { + if (!getSparseTensorEncoding(tp)) + // Not a sparse tensor. + return std::nullopt; + // Sparse compiler knows how to cancel out these casts. 
+ return genTuple(builder, loc, tp, inputs); + }); +} +unsigned StorageLayout::getMemRefFieldIndex(SparseTensorFieldKind kind, + Optional dim) const { + unsigned fieldIdx = -1u; + foreachFieldInSparseTensor( + enc, + [dim, kind, &fieldIdx](unsigned fIdx, SparseTensorFieldKind fKind, + unsigned fDim, DimLevelType dlt) -> bool { + if ((dim && fDim == dim.value() && kind == fKind) || + (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) { + fieldIdx = fIdx; + // Returns false to break the iteration. + return false; + } + return true; + }); + assert(fieldIdx != -1u); + return fieldIdx; +} + +unsigned StorageLayout::getMemRefFieldIndex(StorageSpecifierKind kind, + Optional dim) const { + return getMemRefFieldIndex(toFieldKind(kind), dim); +} + +Value SparseTensorSpecifier::getInitValue(OpBuilder &builder, Location loc, + RankedTensorType rtp) { + return builder.create( + loc, StorageSpecifierType::get(getSparseTensorEncoding(rtp))); +} + +Value SparseTensorSpecifier::getSpecifierField(OpBuilder &builder, Location loc, + StorageSpecifierKind kind, + Optional dim) { + return createIndexCast(builder, loc, + builder.create( + loc, getFieldType(kind, dim), specifier, kind, + fromOptionalInt(specifier.getContext(), dim)), + builder.getIndexType()); +} + +void SparseTensorSpecifier::setSpecifierField(OpBuilder &builder, Location loc, + Value v, + StorageSpecifierKind kind, + Optional dim) { + specifier = builder.create( + loc, specifier, kind, fromOptionalInt(specifier.getContext(), dim), + createIndexCast(builder, loc, v, getFieldType(kind, dim))); +} + +constexpr uint64_t kDataFieldStartingIdx = 0; + +void sparse_tensor::foreachFieldInSparseTensor( + const SparseTensorEncodingAttr enc, + llvm::function_ref + callback) { + assert(enc); + +#define RETURN_ON_FALSE(idx, kind, dim, dlt) \ + if (!(callback(idx, kind, dim, dlt))) \ + return; + + static_assert(kDataFieldStartingIdx == 0); + unsigned fieldIdx = kDataFieldStartingIdx; + // Per-dimension storage. 
+ for (unsigned r = 0, rank = enc.getDimLevelType().size(); r < rank; r++) { + // Dimension level types apply in order to the reordered dimension. + // As a result, the compound type can be constructed directly in the given + // order. + auto dlt = getDimLevelType(enc, r); + if (isCompressedDLT(dlt)) { + RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PtrMemRef, r, dlt); + RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt); + } else if (isSingletonDLT(dlt)) { + RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt); + } else { + assert(isDenseDLT(dlt)); // no fields + } + } + // The values array. + RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::ValMemRef, -1u, + DimLevelType::Undef); + + // Put metadata at the end. + RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::StorageSpec, -1u, + DimLevelType::Undef); + +#undef RETURN_ON_FALSE +} + +void sparse_tensor::foreachFieldAndTypeInSparseTensor( + RankedTensorType rType, + llvm::function_ref + callback) { + auto enc = getSparseTensorEncoding(rType); + assert(enc); + // Construct the basic types. 
+ Type idxType = enc.getIndexType(); + Type ptrType = enc.getPointerType(); + Type eltType = rType.getElementType(); + + Type metaDataType = StorageSpecifierType::get(enc); + // memref pointers + Type ptrMemType = MemRefType::get({ShapedType::kDynamic}, ptrType); + // memref indices + Type idxMemType = MemRefType::get({ShapedType::kDynamic}, idxType); + // memref values + Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType); + + foreachFieldInSparseTensor( + enc, + [metaDataType, ptrMemType, idxMemType, valMemType, + callback](unsigned fieldIdx, SparseTensorFieldKind fieldKind, + unsigned dim, DimLevelType dlt) -> bool { + switch (fieldKind) { + case SparseTensorFieldKind::StorageSpec: + return callback(metaDataType, fieldIdx, fieldKind, dim, dlt); + case SparseTensorFieldKind::PtrMemRef: + return callback(ptrMemType, fieldIdx, fieldKind, dim, dlt); + case SparseTensorFieldKind::IdxMemRef: + return callback(idxMemType, fieldIdx, fieldKind, dim, dlt); + case SparseTensorFieldKind::ValMemRef: + return callback(valMemType, fieldIdx, fieldKind, dim, dlt); + }; + llvm_unreachable("unrecognized field kind"); + }); +} + +unsigned sparse_tensor::getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) { + unsigned numFields = 0; + foreachFieldInSparseTensor(enc, + [&numFields](unsigned, SparseTensorFieldKind, + unsigned, DimLevelType) -> bool { + numFields++; + return true; + }); + return numFields; +} + +unsigned +sparse_tensor::getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc) { + unsigned numFields = 0; // every field is counted by the callback below + foreachFieldInSparseTensor(enc, + [&numFields](unsigned fidx, SparseTensorFieldKind, + unsigned, DimLevelType) -> bool { + if (fidx >= kDataFieldStartingIdx) + numFields++; + return true; + }); + numFields -= 1; // exclude the trailing metadata (StorageSpec) field + assert(numFields == + getNumFieldsFromEncoding(enc) - kDataFieldStartingIdx - 1); + return numFields; +} diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -16,6 +16,7 @@ //===----------------------------------------------------------------------===// #include "CodegenUtils.h" +#include "SparseTensorBuilder.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -40,38 +41,6 @@ // Helper methods. //===----------------------------------------------------------------------===// -/// Returns the "tuple" value of the adapted tensor. -static UnrealizedConversionCastOp getTuple(Value tensor) { - return llvm::cast(tensor.getDefiningOp()); -} - -static SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) { - auto tuple = getTuple(tensor); - return SparseTensorDescriptor(tuple.getResultTypes()[0], tuple.getInputs()); -} - -static MutSparseTensorDescriptor -getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl &fields) { - auto tuple = getTuple(tensor); - fields.assign(tuple.getInputs().begin(), tuple.getInputs().end()); - return MutSparseTensorDescriptor(tuple.getResultTypes()[0], fields); -} - -/// Packs the given values as a "tuple" value. -static Value genTuple(OpBuilder &builder, Location loc, Type tp, - ValueRange values) { - return builder.create(loc, TypeRange(tp), values) - .getResult(0); -} - -static Value genTuple(OpBuilder &builder, Location loc, - SparseTensorDescriptor desc) { - return builder - .create(loc, desc.getTensorType(), - desc.getFields()) - .getResult(0); -} - /// Flatten a list of operands that may contain sparse tensors. static void flattenOperands(ValueRange operands, SmallVectorImpl &flattened) { @@ -145,9 +114,7 @@ // Any other query can consult the dimSizes array at field DimSizesIdx, // accounting for the reordering applied to the sparse storage. 
- Value idx = constantIndex(builder, loc, toStoredDim(rtp, dim)); - return builder.create(loc, desc.getDimSizesMemRef(), idx) - .getResult(); + return desc.getDimSize(builder, loc, toStoredDim(rtp, dim)); } // Gets the dimension size at the given stored dimension 'd', either as a @@ -160,40 +127,24 @@ if (!ShapedType::isDynamic(shape[dim])) return constantIndex(builder, loc, shape[dim]); - return genLoad(builder, loc, desc.getDimSizesMemRef(), - constantIndex(builder, loc, d)); + return desc.getDimSize(builder, loc, d); } static void createPushback(OpBuilder &builder, Location loc, - MutSparseTensorDescriptor desc, unsigned fidx, + MutSparseTensorDescriptor desc, + SparseTensorFieldKind kind, Optional dim, Value value, Value repeat = Value()) { - Type etp = desc.getElementType(fidx); - Value field = desc.getField(fidx); - Value newField = builder.create( - loc, field.getType(), desc.getMemSizesMemRef(), field, - toType(builder, loc, value, etp), APInt(64, getFieldMemSizesIndex(fidx)), - repeat); - desc.setField(fidx, newField); -} + Type etp = desc.getMemRefElementType(kind, dim); + Value field = desc.getMemRefField(kind, dim); + StorageSpecifierKind specFieldKind = toSpecifierKind(kind); -/// Maps a sparse tensor type to the appropriate compounded buffers. 
-static Optional -convertSparseTensorType(Type type, SmallVectorImpl &fields) { - auto enc = getSparseTensorEncoding(type); - if (!enc) - return std::nullopt; + auto pushBackOp = builder.create( + loc, desc.getSpecifierField(builder, loc, specFieldKind, dim), field, + toType(builder, loc, value, etp), repeat); - RankedTensorType rType = type.cast(); - foreachFieldAndTypeInSparseTensor( - rType, - [&fields](Type fieldType, unsigned fieldIdx, - SparseTensorFieldKind /*fieldKind*/, unsigned /*dim*/, - DimLevelType /*dlt*/) -> bool { - assert(fieldIdx == fields.size()); - fields.push_back(fieldType); - return true; - }); - return success(); + desc.setMemRefField(kind, dim, pushBackOp.getOutBuffer()); + desc.setSpecifierField(builder, loc, specFieldKind, dim, + pushBackOp.getNewSize()); } /// Generates code that allocates a sparse storage scheme for given rank. @@ -209,8 +160,8 @@ // the desired "linear + 1" length property at all times. Type ptrType = getSparseTensorEncoding(rtp).getPointerType(); Value ptrZero = constantZero(builder, loc, ptrType); - createPushback(builder, loc, desc, desc.getPtrMemRefIndex(r), ptrZero, - linear); + createPushback(builder, loc, desc, SparseTensorFieldKind::PtrMemRef, r, + ptrZero, linear); return; } if (isSingletonDim(rtp, r)) { @@ -225,7 +176,8 @@ } // Reached values array so prepare for an insertion. Value valZero = constantZero(builder, loc, rtp.getElementType()); - createPushback(builder, loc, desc, desc.getValMemRefIndex(), valZero, linear); + createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef, + std::nullopt, valZero, linear); } /// Creates allocation operation. 
@@ -256,22 +208,20 @@ foreachFieldAndTypeInSparseTensor( rtp, - [&builder, &fields, loc, heuristic, + [&builder, &fields, rtp, loc, heuristic, enableInit](Type fType, unsigned fIdx, SparseTensorFieldKind fKind, unsigned /*dim*/, DimLevelType /*dlt*/) -> bool { assert(fields.size() == fIdx); - auto memRefTp = fType.cast(); Value field; switch (fKind) { - case SparseTensorFieldKind::DimSizes: - case SparseTensorFieldKind::MemSizes: - field = builder.create(loc, memRefTp); + case SparseTensorFieldKind::StorageSpec: + field = SparseTensorSpecifier::getInitValue(builder, loc, rtp); break; case SparseTensorFieldKind::PtrMemRef: case SparseTensorFieldKind::IdxMemRef: case SparseTensorFieldKind::ValMemRef: - field = - createAllocation(builder, loc, memRefTp, heuristic, enableInit); + field = createAllocation(builder, loc, fType.cast(), + heuristic, enableInit); break; } assert(field); @@ -296,21 +246,18 @@ // to all zeros, sets the dimSizes to known values and gives all pointer // fields an initial zero entry, so that it is easier to maintain the // "linear + 1" length property. - builder.create( - loc, constantZero(builder, loc, builder.getIndexType()), - desc.getMemSizesMemRef()); // zero memSizes - Value ptrZero = constantZero(builder, loc, getSparseTensorEncoding(rtp).getPointerType()); for (unsigned r = 0; r < rank; r++) { unsigned ro = toOrigDim(rtp, r); // Fills dim sizes array. - genStore(builder, loc, sizes[ro], desc.getDimSizesMemRef(), - constantIndex(builder, loc, r)); + desc.setDimSize(builder, loc, r, sizes[ro]); // Pushes a leading zero to pointers memref. 
- if (isCompressedDim(rtp, r)) - createPushback(builder, loc, desc, desc.getPtrMemRefIndex(r), ptrZero); + if (isCompressedDim(rtp, r)) { + createPushback(builder, loc, desc, SparseTensorFieldKind::PtrMemRef, r, + ptrZero); + } } allocSchemeForRank(builder, loc, desc, /*rank=*/0); } @@ -348,10 +295,11 @@ unsigned ptrIndex = desc.getPtrMemRefIndex(d); Value one = constantIndex(builder, loc, 1); Value pp1 = builder.create(loc, pos, one); - Value plo = genLoad(builder, loc, desc.getField(ptrIndex), pos); - Value phi = genLoad(builder, loc, desc.getField(ptrIndex), pp1); - Value psz = constantIndex(builder, loc, getFieldMemSizesIndex(idxIndex)); - Value msz = genLoad(builder, loc, desc.getMemSizesMemRef(), psz); + Value plo = genLoad(builder, loc, desc.getMemRefField(ptrIndex), pos); + Value phi = genLoad(builder, loc, desc.getMemRefField(ptrIndex), pp1); + Value msz = desc.getIdxMemSize(builder, loc, d); + // Value msz = desc.getMemSize(builder, loc, getFieldMemSizesIndex(idxIndex)); + Value phim1 = builder.create( loc, toType(builder, loc, phi, indexType), one); // Conditional expression. @@ -361,14 +309,14 @@ scf::IfOp ifOp1 = builder.create(loc, types, lt, /*else*/ true); types.pop_back(); builder.setInsertionPointToStart(&ifOp1.getThenRegion().front()); - Value crd = genLoad(builder, loc, desc.getField(idxIndex), phim1); + Value crd = genLoad(builder, loc, desc.getMemRefField(idxIndex), phim1); Value eq = builder.create(loc, arith::CmpIPredicate::eq, toType(builder, loc, crd, indexType), indices[d]); builder.create(loc, eq); builder.setInsertionPointToStart(&ifOp1.getElseRegion().front()); if (d > 0) - genStore(builder, loc, msz, desc.getField(ptrIndex), pos); + genStore(builder, loc, msz, desc.getMemRefField(ptrIndex), pos); builder.create(loc, constantI1(builder, loc, false)); builder.setInsertionPointAfter(ifOp1); Value p = ifOp1.getResult(0); @@ -395,8 +343,9 @@ // If !present (changes fields, update next). 
builder.setInsertionPointToStart(&ifOp2.getElseRegion().front()); Value mszp1 = builder.create(loc, msz, one); - genStore(builder, loc, mszp1, desc.getField(ptrIndex), pp1); - createPushback(builder, loc, desc, idxIndex, indices[d]); + genStore(builder, loc, mszp1, desc.getMemRefField(ptrIndex), pp1); + createPushback(builder, loc, desc, SparseTensorFieldKind::IdxMemRef, d, + indices[d]); // Prepare the next dimension "as needed". if ((d + 1) < rank) allocSchemeForRank(builder, loc, desc, d + 1); @@ -458,7 +407,8 @@ // indices[d].push_back(i[d]) // pos[d] = pos[d-1] // - createPushback(builder, loc, desc, desc.getIdxMemRefIndex(d), indices[d]); + createPushback(builder, loc, desc, SparseTensorFieldKind::IdxMemRef, d, + indices[d]); } else { assert(isDenseDim(rtp, d)); // Construct the new position as: @@ -471,7 +421,8 @@ } // Reached the actual value append/insert. if (!isDenseDim(rtp, rank - 1)) - createPushback(builder, loc, desc, desc.getValMemRefIndex(), value); + createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef, + std::nullopt, value); else genStore(builder, loc, value, desc.getValMemRef(), pos); builder.create(loc, fields); @@ -564,8 +515,7 @@ if (d > 0) { Type ptrType = getSparseTensorEncoding(rtp).getPointerType(); Value ptrMemRef = desc.getPtrMemRef(d); - Value mz = constantIndex(builder, loc, desc.getPtrMemSizesIndex(d)); - Value hi = genLoad(builder, loc, desc.getMemSizesMemRef(), mz); + Value hi = desc.getPtrMemSize(builder, loc, d); Value zero = constantIndex(builder, loc, 0); Value one = constantIndex(builder, loc, 1); // Vector of only one, but needed by createFor's prototype. @@ -722,6 +672,7 @@ bool enableInit) : OpConversionPattern(typeConverter, context), enableBufferInitialization(enableInit) {} + LogicalResult matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -760,8 +711,8 @@ // Replace the sparse tensor deallocation with field deallocations. 
Location loc = op.getLoc(); - auto tuple = getTuple(adaptor.getTensor()); - for (auto input : tuple.getInputs()) + auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); + for (auto input : desc.getMemRefFields()) // Deallocate every buffer used to store the sparse tensor handler. rewriter.create(loc, input); @@ -1017,36 +968,13 @@ ConversionPatternRewriter &rewriter) const override { // Query memSizes for the actually stored values size. auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); - Value field = - constantIndex(rewriter, op.getLoc(), desc.getValMemSizesIndex()); - rewriter.replaceOpWithNewOp(op, desc.getMemSizesMemRef(), - field); + rewriter.replaceOp(op, desc.getValMemSize(rewriter, op.getLoc())); return success(); } }; } // namespace -//===----------------------------------------------------------------------===// -// Sparse tensor type conversion into an actual buffer. -//===----------------------------------------------------------------------===// - -mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() { - addConversion([](Type type) { return type; }); - addConversion(convertSparseTensorType); - - // Required by scf.for 1:N type conversion. - addSourceMaterialization([](OpBuilder &builder, RankedTensorType tp, - ValueRange inputs, - Location loc) -> Optional { - if (!getSparseTensorEncoding(tp)) - // Not a sparse tensor. - return std::nullopt; - // Sparse compiler knows how to cancel out these casts. - return genTuple(builder, loc, tp, inputs); - }); -} - //===----------------------------------------------------------------------===// // Public method for populating conversion rules. 
//===----------------------------------------------------------------------===// @@ -1064,6 +992,7 @@ SparseToValuesConverter, SparseConvertConverter, SparseNumberOfEntriesConverter>(typeConverter, patterns.getContext()); + patterns.add(typeConverter, patterns.getContext(), enableBufferInitialization); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -28,6 +28,7 @@ #define GEN_PASS_DEF_SPARSETENSORCODEGEN #define GEN_PASS_DEF_SPARSEBUFFERREWRITE #define GEN_PASS_DEF_SPARSEVECTORIZATION +#define GEN_PASS_DEF_SPARSESPECIFIERTOLLVM #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc" } // namespace mlir @@ -193,9 +194,14 @@ target.addLegalOp(); target.addLegalOp(); target.addLegalOp(); - // All dynamic rules below accept new function, call, return, and various - // tensor and bufferization operations as legal output of the rewriting - // provided that all sparse tensor types have been fully rewritten. + // Storage specifier outlives sparse tensor pipeline. + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + // All dynamic rules below accept new function, call, return, and + // various tensor and bufferization operations as legal output of the + // rewriting provided that all sparse tensor types have been fully + // rewritten. target.addDynamicallyLegalOp([&](func::FuncOp op) { return converter.isSignatureLegal(op.getFunctionType()); }); @@ -271,6 +277,44 @@ } }; +struct SparseSpecifierToLLVMPass + : public impl::SparseSpecifierToLLVMBase { + + SparseSpecifierToLLVMPass() = default; + + void runOnOperation() override { + auto *ctx = &getContext(); + ConversionTarget target(*ctx); + RewritePatternSet patterns(ctx); + SparseSpecifierToLLVMTypeConverter converter; + + // All ops in the sparse dialect must go! 
+ target.addIllegalDialect(); + target.addDynamicallyLegalOp([&](func::FuncOp op) { + return converter.isSignatureLegal(op.getFunctionType()); + }); + target.addDynamicallyLegalOp([&](func::CallOp op) { + return converter.isSignatureLegal(op.getCalleeType()); + }); + target.addDynamicallyLegalOp([&](func::ReturnOp op) { + return converter.isLegal(op.getOperandTypes()); + }); + target.addLegalDialect(); + + populateFunctionOpInterfaceTypeConversionPattern(patterns, + converter); + populateCallOpTypeConversionPattern(patterns, converter); + populateBranchOpInterfaceTypeConversionPattern(patterns, converter); + populateReturnOpTypeConversionPattern(patterns, converter); + scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns, + target); + populateSparseSpecifierToLLVMPatterns(converter, patterns); + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + signalPassFailure(); + } +}; + } // namespace //===----------------------------------------------------------------------===// @@ -355,3 +399,7 @@ return std::make_unique( vectorLength, enableVLAVectorization, enableSIMDIndex32); } + +std::unique_ptr mlir::createSparseSpecifierToLLVMPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp @@ -147,6 +147,7 @@ } else { pm.addPass(createSparseTensorCodegenPass(enableBufferInitialization)); pm.addPass(createSparseBufferRewritePass(enableBufferInitialization)); + pm.addPass(createSparseSpecifierToLLVMPass()); } if (failed(runPipeline(pm, getOperation()))) return signalPassFailure(); diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir 
--- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir +++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir @@ -1,15 +1,14 @@ // RUN: mlir-opt %s -split-input-file --sparse-buffer-rewrite --canonicalize --cse | FileCheck %s // CHECK-LABEL: func @sparse_push_back( -// CHECK-SAME: %[[A:.*]]: memref, +// CHECK-SAME: %[[A:.*]]: index, // CHECK-SAME: %[[B:.*]]: memref, -// CHECK-SAME: %[[C:.*]]: f64) -> memref { +// CHECK-SAME: %[[C:.*]]: f64) -> (memref, index) { // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[P1:.*]] = memref.dim %[[B]], %[[C0]] -// CHECK: %[[S1:.*]] = memref.load %[[A]]{{\[}}%[[C2]]] -// CHECK: %[[S2:.*]] = arith.addi %[[S1]], %[[C1]] : index +// CHECK: %[[S2:.*]] = arith.addi %[[A]], %[[C1]] : index // CHECK: %[[T:.*]] = arith.cmpi ugt, %[[S2]], %[[P1]] // CHECK: %[[M:.*]] = scf.if %[[T]] -> (memref) { // CHECK: %[[P2:.*]] = arith.muli %[[P1]], %[[C2]] @@ -18,25 +17,23 @@ // CHECK: } else { // CHECK: scf.yield %[[B]] : memref // CHECK: } -// CHECK: memref.store %[[C]], %[[M]]{{\[}}%[[S1]]] -// CHECK: memref.store %[[S2]], %[[A]]{{\[}}%[[C2]]] -// CHECK: return %[[M]] : memref -func.func @sparse_push_back(%arg0: memref, %arg1: memref, %arg2: f64) -> memref { - %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref, memref, f64 - return %0 : memref +// CHECK: memref.store %[[C]], %[[M]]{{\[}}%[[A]]] +// CHECK: return %[[M]], %[[S2]] +func.func @sparse_push_back(%arg0: index, %arg1: memref, %arg2: f64) -> (memref, index) { + %0:2 = sparse_tensor.push_back %arg0, %arg1, %arg2 : index, memref, f64 + return %0#0, %0#1 : memref, index } // ----- // CHECK-LABEL: func @sparse_push_back_n( -// CHECK-SAME: %[[A:.*]]: memref, +// CHECK-SAME: %[[S1:.*]]: index, // CHECK-SAME: %[[B:.*]]: memref, // CHECK-SAME: %[[C:.*]]: f64, -// CHECK-SAME: %[[D:.*]]: index) -> memref { +// CHECK-SAME: %[[D:.*]]: index) -> 
(memref, index) { // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[P1:.*]] = memref.dim %[[B]], %[[C0]] -// CHECK: %[[S1:.*]] = memref.load %[[A]]{{\[}}%[[C2]]] // CHECK: %[[S2:.*]] = arith.addi %[[S1]], %[[D]] : index // CHECK: %[[T:.*]] = arith.cmpi ugt, %[[S2]], %[[P1]] // CHECK: %[[M:.*]] = scf.if %[[T]] -> (memref) { @@ -55,29 +52,25 @@ // CHECK: } // CHECK: %[[S:.*]] = memref.subview %[[M]]{{\[}}%[[S1]]] {{\[}}%[[D]]] [1] // CHECK: linalg.fill ins(%[[C]] : f64) outs(%[[S]] -// CHECK: memref.store %[[S2]], %[[A]]{{\[}}%[[C2]]] -// CHECK: return %[[M]] : memref -func.func @sparse_push_back_n(%arg0: memref, %arg1: memref, %arg2: f64, %arg3: index) -> memref { - %0 = sparse_tensor.push_back %arg0, %arg1, %arg2, %arg3 {idx = 2 : index} : memref, memref, f64, index - return %0 : memref +// CHECK: return %[[M]], %[[S2]] : memref, index +func.func @sparse_push_back_n(%arg0: index, %arg1: memref, %arg2: f64, %arg3: index) -> (memref, index) { + %0:2 = sparse_tensor.push_back %arg0, %arg1, %arg2, %arg3 : index, memref, f64, index + return %0#0, %0#1 : memref, index } // ----- // CHECK-LABEL: func @sparse_push_back_inbound( -// CHECK-SAME: %[[A:.*]]: memref, +// CHECK-SAME: %[[S1:.*]]: index, // CHECK-SAME: %[[B:.*]]: memref, -// CHECK-SAME: %[[C:.*]]: f64) -> memref { +// CHECK-SAME: %[[C:.*]]: f64) -> (memref, index) { // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[S1:.*]] = memref.load %[[A]]{{\[}}%[[C2]]] // CHECK: %[[S2:.*]] = arith.addi %[[S1]], %[[C1]] // CHECK: memref.store %[[C]], %[[B]]{{\[}}%[[S1]]] -// CHECK: memref.store %[[S2]], %[[A]]{{\[}}%[[C2]]] -// CHECK: return %[[B]] : memref -func.func @sparse_push_back_inbound(%arg0: memref, %arg1: memref, %arg2: f64) -> memref { - %0 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 {idx = 2 : index} : memref, memref, f64 - return %0 : memref +// CHECK: return %[[B]], 
%[[S2]] : memref, index +func.func @sparse_push_back_inbound(%arg0: index, %arg1: memref, %arg2: f64) -> (memref, index) { + %0:2 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 : index, memref, f64 + return %0#0, %0#1 : memref, index } // ----- diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir --- a/mlir/test/Dialect/SparseTensor/codegen.mlir +++ b/mlir/test/Dialect/SparseTensor/codegen.mlir @@ -47,31 +47,28 @@ }> // CHECK-LABEL: func @sparse_nop( -// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref) -// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]] : -// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref +// CHECK-SAME: %[[A1:.*]]: memref, +// CHECK-SAME: %[[A2:.*]]: memref, +// CHECK-SAME: %[[A3:.*]]: memref, +// CHECK-SAME: %[[A4:.*]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>) +// CHECK: return %[[A1]], %[[A2]], %[[A3]], %[[A4]] : +// CHECK-SAME: memref, memref, memref, !llvm.struct<(array<1 x i64>, array<3 x i64>)> func.func @sparse_nop(%arg0: tensor) -> tensor { return %arg0 : tensor } // CHECK-LABEL: func @sparse_nop_multi_ret( -// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: memref<1xindex>, -// CHECK-SAME: %[[A6:.*6]]: memref<3xindex>, -// CHECK-SAME: %[[A7:.*7]]: memref, -// CHECK-SAME: %[[A8:.*8]]: memref, -// CHECK-SAME: %[[A9:.*9]]: memref) -// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[A9]] : -// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref, -// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref +// CHECK-SAME: %[[A1:.*0]]: memref, +// 
CHECK-SAME: %[[A2:.*1]]: memref, +// CHECK-SAME: %[[A3:.*2]]: memref, +// CHECK-SAME: %[[A4:.*3]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>, +// CHECK-SAME: %[[A6:.*4]]: memref, +// CHECK-SAME: %[[A7:.*5]]: memref, +// CHECK-SAME: %[[A8:.*6]]: memref, +// CHECK-SAME: %[[A9:.*7]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>) +// CHECK: return %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A6]], %[[A7]], %[[A8]], %[[A9]] : +// CHECK-SAME: memref, memref, memref, !llvm.struct<(array<1 x i64>, array<3 x i64>)>, +// CHECK-SAME: memref, memref, memref, !llvm.struct<(array<1 x i64>, array<3 x i64>)> func.func @sparse_nop_multi_ret(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { @@ -79,20 +76,18 @@ } // CHECK-LABEL: func @sparse_nop_call( -// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: memref<1xindex>, -// CHECK-SAME: %[[A6:.*6]]: memref<3xindex>, -// CHECK-SAME: %[[A7:.*7]]: memref, -// CHECK-SAME: %[[A8:.*8]]: memref, -// CHECK-SAME: %[[A9:.*9]]: memref) -// CHECK: %[[T:.*]]:10 = call @sparse_nop_multi_ret(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[A9]]) -// CHECK: return %[[T]]#0, %[[T]]#1, %[[T]]#2, %[[T]]#3, %[[T]]#4, %[[T]]#5, %[[T]]#6, %[[T]]#7, %[[T]]#8, %[[T]]#9 : -// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref, -// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref +// CHECK-SAME: %[[A1:.*0]]: memref, +// CHECK-SAME: %[[A2:.*1]]: memref, +// CHECK-SAME: %[[A3:.*2]]: memref, +// CHECK-SAME: %[[A4:.*3]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>, +// CHECK-SAME: %[[A6:.*4]]: memref, +// CHECK-SAME: %[[A7:.*5]]: memref, +// CHECK-SAME: %[[A8:.*6]]: memref, +// CHECK-SAME: %[[A9:.*7]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>) +// CHECK: %[[T:.*]]:8 = call @sparse_nop_multi_ret(%[[A1]], 
%[[A2]], %[[A3]], %[[A4]], %[[A6]], %[[A7]], %[[A8]], %[[A9]]) +// CHECK: return %[[T]]#0, %[[T]]#1, %[[T]]#2, %[[T]]#3, %[[T]]#4, %[[T]]#5, %[[T]]#6, %[[T]]#7 : +// CHECK-SAME: memref, memref, memref, !llvm.struct<(array<1 x i64>, array<3 x i64>)>, +// CHECK-SAME: memref, memref, memref, !llvm.struct<(array<1 x i64>, array<3 x i64>)> func.func @sparse_nop_call(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { @@ -103,68 +98,61 @@ } // CHECK-LABEL: func @sparse_nop_cast( -// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref) -// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]] : -// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref +// CHECK-SAME: %[[A1:.*]]: memref, +// CHECK-SAME: %[[A2:.*]]: memref, +// CHECK-SAME: %[[A3:.*]]: memref, +// CHECK-SAME: %[[A4:.*]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>) +// CHECK: return %[[A1]], %[[A2]], %[[A3]], %[[A4]] : func.func @sparse_nop_cast(%arg0: tensor<64xf32, #SparseVector>) -> tensor { %0 = tensor.cast %arg0 : tensor<64xf32, #SparseVector> to tensor return %0 : tensor } // CHECK-LABEL: func @sparse_nop_cast_3d( -// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref) -// CHECK: return %[[A0]], %[[A1]], %[[A2]] : -// CHECK-SAME: memref<3xindex>, memref<1xindex>, memref +// CHECK-SAME: %[[A1:.*]]: memref, +// CHECK-SAME: %[[A2:.*]]: !llvm.struct<(array<3 x i64>, array<1 x i64>)>) +// CHECK: return %[[A1]], %[[A2]] : +// CHECK-SAME: memref, !llvm.struct<(array<3 x i64>, array<1 x i64>)> func.func @sparse_nop_cast_3d(%arg0: tensor<10x20x30xf32, #Dense3D>) -> tensor { %0 = tensor.cast %arg0 : tensor<10x20x30xf32, #Dense3D> to tensor return %0 : tensor } // CHECK-LABEL: func @sparse_dense_2d( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: 
memref<1xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref) +// CHECK-SAME: %[[A1:.*]]: memref, +// CHECK-SAME: %[[A2:.*]]: !llvm.struct<(array<2 x i64>, array<1 x i64>)>) { // CHECK: return func.func @sparse_dense_2d(%arg0: tensor) { return } // CHECK-LABEL: func @sparse_row( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref) +// CHECK-SAME: %[[A1:.*]]: memref, +// CHECK-SAME: %[[A2:.*]]: memref, +// CHECK-SAME: %[[A3:.*]]: memref, +// CHECK-SAME: %[[A4:.*]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>) { // CHECK: return func.func @sparse_row(%arg0: tensor) { return } // CHECK-LABEL: func @sparse_csr( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref) +// CHECK-SAME: %[[A1:.*]]: memref, +// CHECK-SAME: %[[A2:.*]]: memref, +// CHECK-SAME: %[[A3:.*]]: memref, +// CHECK-SAME: %[[A4:.*]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>) { // CHECK: return func.func @sparse_csr(%arg0: tensor) { return } // CHECK-LABEL: func @sparse_dcsr( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<5xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: memref, -// CHECK-SAME: %[[A6:.*6]]: memref) +// CHECK-SAME: %[[A1:.*0]]: memref, +// CHECK-SAME: %[[A2:.*1]]: memref, +// CHECK-SAME: %[[A3:.*2]]: memref, +// CHECK-SAME: %[[A4:.*3]]: memref, +// CHECK-SAME: %[[A5:.*4]]: memref, +// CHECK-SAME: %[[A6:.*5]]: !llvm.struct<(array<2 x i64>, array<5 x i64>)>) // CHECK: return func.func @sparse_dcsr(%arg0: tensor) { return @@ -175,9 +163,8 @@ // fold using the original static dimension sizes. 
// // CHECK-LABEL: func @sparse_dense_3d( -// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref) +// CHECK-SAME: %[[A1:.*]]: memref, +// CHECK-SAME: %[[A2:.*]]: !llvm.struct<(array<3 x i64>, array<1 x i64>)>) // CHECK: %[[C:.*]] = arith.constant 20 : index // CHECK: return %[[C]] : index func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index { @@ -192,12 +179,11 @@ // since the latter honors the dimOrdering. // // CHECK-LABEL: func @sparse_dense_3d_dyn( -// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref) -// CHECK: %[[C:.*]] = arith.constant 2 : index -// CHECK: %[[L:.*]] = memref.load %[[A0]][%[[C]]] : memref<3xindex> -// CHECK: return %[[L]] : index +// CHECK-SAME: %[[A1:.*]]: memref, +// CHECK-SAME: %[[A2:.*]]: !llvm.struct<(array<3 x i64>, array<1 x i64>)>) +// CHECK: %[[A3:.*]] = llvm.extractvalue %[[A2]][0, 2] +// CHECK: %[[A4:.*]] = arith.index_cast %[[A3]] : i64 to index +// CHECK: return %[[A4]] : index func.func @sparse_dense_3d_dyn(%arg0: tensor) -> index { %c = arith.constant 1 : index %0 = tensor.dim %arg0, %c : tensor @@ -205,55 +191,51 @@ } // CHECK-LABEL: func @sparse_pointers_dcsr( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<5xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: memref, -// CHECK-SAME: %[[A6:.*6]]: memref) -// CHECK: return %[[A4]] : memref +// CHECK-SAME: %[[A1:.*0]]: memref, +// CHECK-SAME: %[[A2:.*1]]: memref, +// CHECK-SAME: %[[A3:.*2]]: memref, +// CHECK-SAME: %[[A4:.*3]]: memref, +// CHECK-SAME: %[[A5:.*4]]: memref, +// CHECK-SAME: %[[A6:.*5]]: !llvm.struct<(array<2 x i64>, array<5 x i64>)> +// CHECK: return %[[A3]] : memref func.func @sparse_pointers_dcsr(%arg0: tensor) -> memref { %0 = sparse_tensor.pointers %arg0 { dimension 
= 1 : index } : tensor to memref return %0 : memref } // CHECK-LABEL: func @sparse_indices_dcsr( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<5xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: memref, -// CHECK-SAME: %[[A6:.*6]]: memref) -// CHECK: return %[[A5]] : memref +// CHECK-SAME: %[[A1:.*0]]: memref, +// CHECK-SAME: %[[A2:.*1]]: memref, +// CHECK-SAME: %[[A3:.*2]]: memref, +// CHECK-SAME: %[[A4:.*3]]: memref, +// CHECK-SAME: %[[A5:.*4]]: memref, +// CHECK-SAME: %[[A6:.*5]]: !llvm.struct<(array<2 x i64>, array<5 x i64>)> +// CHECK: return %[[A4]] : memref func.func @sparse_indices_dcsr(%arg0: tensor) -> memref { %0 = sparse_tensor.indices %arg0 { dimension = 1 : index } : tensor to memref return %0 : memref } // CHECK-LABEL: func @sparse_values_dcsr( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<5xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: memref, -// CHECK-SAME: %[[A6:.*6]]: memref) -// CHECK: return %[[A6]] : memref +// CHECK-SAME: %[[A1:.*0]]: memref, +// CHECK-SAME: %[[A2:.*1]]: memref, +// CHECK-SAME: %[[A3:.*2]]: memref, +// CHECK-SAME: %[[A4:.*3]]: memref, +// CHECK-SAME: %[[A5:.*4]]: memref, +// CHECK-SAME: %[[A6:.*5]]: !llvm.struct<(array<2 x i64>, array<5 x i64>)>) +// CHECK: return %[[A5]] : memref func.func @sparse_values_dcsr(%arg0: tensor) -> memref { %0 = sparse_tensor.values %arg0 : tensor to memref return %0 : memref } // CHECK-LABEL: func @sparse_noe( -// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref) -// CHECK: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[NOE:.*]] = memref.load %[[A1]][%[[C2]]] : memref<3xindex> 
+// CHECK-SAME: %[[A1:.*]]: memref, +// CHECK-SAME: %[[A2:.*]]: memref, +// CHECK-SAME: %[[A3:.*]]: memref, +// CHECK-SAME: %[[A4:.*]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>) +// CHECK: %[[A5:.*]] = llvm.extractvalue %[[A4]][1, 2] : !llvm.struct<(array<1 x i64>, array<3 x i64>)> +// CHECK: %[[NOE:.*]] = arith.index_cast %[[A5]] : i64 to index // CHECK: return %[[NOE]] : index func.func @sparse_noe(%arg0: tensor<128xf64, #SparseVector>) -> index { %0 = sparse_tensor.number_of_entries %arg0 : tensor<128xf64, #SparseVector> @@ -261,70 +243,71 @@ } // CHECK-LABEL: func @sparse_dealloc_csr( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref) -// CHECK: memref.dealloc %[[A0]] : memref<2xindex> -// CHECK: memref.dealloc %[[A1]] : memref<3xindex> -// CHECK: memref.dealloc %[[A2]] : memref -// CHECK: memref.dealloc %[[A3]] : memref -// CHECK: memref.dealloc %[[A4]] : memref +// CHECK-SAME: %[[A1:.*]]: memref, +// CHECK-SAME: %[[A2:.*]]: memref, +// CHECK-SAME: %[[A3:.*]]: memref, +// CHECK-SAME: %[[A4:.*]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>) +// CHECK: memref.dealloc %[[A1]] : memref +// CHECK: memref.dealloc %[[A2]] : memref +// CHECK: memref.dealloc %[[A3]] : memref // CHECK: return func.func @sparse_dealloc_csr(%arg0: tensor) { bufferization.dealloc_tensor %arg0 : tensor return } -// CHECK-LABEL: func @sparse_alloc_csc( -// CHECK-SAME: %[[A:.*]]: index) -> -// CHECK-SAME: memref<2xindex>, memref<3xindex>, memref, memref, memref -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index -// CHECK: %[[T0:.*]] = memref.alloc() : memref<2xindex> -// CHECK: %[[T1:.*]] = memref.alloc() : memref<3xindex> -// CHECK: %[[T2:.*]] = memref.alloc() : memref<16xindex> -// CHECK: %[[T3:.*]] = memref.cast %[[T2]] : 
memref<16xindex> to memref -// CHECK: %[[T4:.*]] = memref.alloc() : memref<16xindex> -// CHECK: %[[T5:.*]] = memref.cast %[[T4]] : memref<16xindex> to memref -// CHECK: %[[T6:.*]] = memref.alloc() : memref<16xf64> -// CHECK: %[[T7:.*]] = memref.cast %[[T6]] : memref<16xf64> to memref -// CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T1]] : memref<3xindex>) -// CHECK: memref.store %[[A]], %[[T0]][%[[C0]]] : memref<2xindex> -// CHECK: memref.store %[[C10]], %[[T0]][%[[C1]]] : memref<2xindex> -// CHECK: %[[P0:.*]] = sparse_tensor.push_back %[[T1]], %[[T3]] -// CHECK: %[[P1:.*]] = sparse_tensor.push_back %[[T1]], %[[P0]] -// CHECK: return %[[T0]], %[[T1]], %[[P1]], %[[T5]], %[[T7]] : -// CHECK-SAME: memref<2xindex>, memref<3xindex>, memref, memref, memref +// CHECK-LABEL: func.func @sparse_alloc_csc( +// CHECK-SAME: %[[A0:.*]]: index) -> +// CHECK-SAME: (memref, memref, memref, !llvm.struct<(array<2 x i64>, array<3 x i64>)>) +// CHECK-DAG: %[[A1:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[A2:.*]] = arith.constant 10 : i64 +// CHECK-DAG: %[[A3:.*]] = arith.constant 0 : i64 +// CHECK: %[[A4:.*]] = memref.alloc() : memref<16xindex> +// CHECK: %[[A5:.*]] = memref.cast %[[A4]] +// CHECK: %[[A6:.*]] = memref.alloc() : memref<16xindex> +// CHECK: %[[A7:.*]] = memref.cast %[[A6]] +// CHECK: %[[A8:.*]] = memref.alloc() : memref<16xf64> +// CHECK: %[[A9:.*]] = memref.cast %[[A8]] +// CHECK: %[[A10:.*]] = llvm.mlir.undef +// CHECK: %[[A11:.*]] = llvm.insertvalue %[[A3]], %[[A10]][1, 0] +// CHECK: %[[A12:.*]] = llvm.insertvalue %[[A3]], %[[A11]][1, 1] +// CHECK: %[[A13:.*]] = llvm.insertvalue %[[A3]], %[[A12]][1, 2] +// CHECK: %[[A14:.*]] = arith.index_cast %[[A0]] : index to i64 +// CHECK: %[[A15:.*]] = llvm.insertvalue %[[A14]], %[[A13]][0, 0] +// CHECK: %[[A16:.*]] = llvm.insertvalue %[[A2]], %[[A15]][0, 1] +// CHECK: %[[A17:.*]], %[[A18:.*]] = sparse_tensor.push_back %[[A1]], %[[A5]], %[[A1]] : index, memref, index +// CHECK: %[[A19:.*]] = arith.index_cast %[[A18]] : 
index to i64 +// CHECK: %[[A20:.*]] = llvm.insertvalue %[[A19]], %[[A16]][1, 0] +// CHECK: %[[A21:.*]], %[[A22:.*]] = sparse_tensor.push_back %[[A18]], %[[A17]], %[[A1]], %[[A0]] : index, memref, index, index +// CHECK: %[[A23:.*]] = arith.index_cast %[[A22]] : index to i64 +// CHECK: %[[A24:.*]] = llvm.insertvalue %[[A23]], %[[A20]][1, 0] +// CHECK: return %[[A21]], %[[A7]], %[[A9]], %[[A24]] : memref, memref, memref, !llvm.struct<(array<2 x i64>, array<3 x i64>)> func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> { %0 = bufferization.alloc_tensor(%arg0) : tensor<10x?xf64, #CSC> %1 = sparse_tensor.load %0 : tensor<10x?xf64, #CSC> return %1 : tensor<10x?xf64, #CSC> } -// CHECK-LABEL: func @sparse_alloc_3d() -> -// CHECK-SAME: memref<3xindex>, memref<1xindex>, memref -// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64 -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index -// CHECK-DAG: %[[C20:.*]] = arith.constant 20 : index -// CHECK-DAG: %[[C30:.*]] = arith.constant 30 : index -// CHECK-DAG: %[[C6000:.*]] = arith.constant 6000 : index -// CHECK: %[[A0:.*]] = memref.alloc() : memref<3xindex> -// CHECK: %[[A1:.*]] = memref.alloc() : memref<1xindex> -// CHECK: %[[AV:.*]] = memref.alloc() : memref<16xf64> -// CHECK: %[[A2:.*]] = memref.cast %[[AV]] : memref<16xf64> to memref -// CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[A1]] : memref<1xindex>) -// CHECK: memref.store %[[C30]], %[[A0]][%[[C0]]] : memref<3xindex> -// CHECK: memref.store %[[C10]], %[[A0]][%[[C1]]] : memref<3xindex> -// CHECK: memref.store %[[C20]], %[[A0]][%[[C2]]] : memref<3xindex> -// CHECK: %[[P:.*]] = sparse_tensor.push_back %[[A1]], %[[A2]], %[[F0]], %[[C6000]] -// CHECK: return %[[A0]], %[[A1]], %[[P]] : -// CHECK-SAME: memref<3xindex>, memref<1xindex>, memref +// CHECK-LABEL: func.func @sparse_alloc_3d() +// 
CHECK-SAME: -> (memref, !llvm.struct<(array<3 x i64>, array<1 x i64>)>) { +// CHECK-DAG: %[[A0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[A1:.*]] = arith.constant 6000 : index +// CHECK-DAG: %[[A2:.*]] = arith.constant 20 : i64 +// CHECK-DAG: %[[A3:.*]] = arith.constant 10 : i64 +// CHECK-DAG: %[[A4:.*]] = arith.constant 30 : i64 +// CHECK-DAG: %[[A5:.*]] = arith.constant 0.000000e+00 : f64 +// CHECK-DAG: %[[A6:.*]] = arith.constant 0 : i64 +// CHECK: %[[A7:.*]] = memref.alloc() : memref<16xf64> +// CHECK: %[[A8:.*]] = memref.cast %[[A7]] : memref<16xf64> to memref +// CHECK: %[[A9:.*]] = llvm.mlir.undef +// CHECK: %[[A10:.*]] = llvm.insertvalue %[[A6]], %[[A9]][1, 0] +// CHECK: %[[A11:.*]] = llvm.insertvalue %[[A4]], %[[A10]][0, 0] +// CHECK: %[[A12:.*]] = llvm.insertvalue %[[A3]], %[[A11]][0, 1] +// CHECK: %[[A13:.*]] = llvm.insertvalue %[[A2]], %[[A12]][0, 2] +// CHECK: %[[A14:.*]], %[[A15:.*]] = sparse_tensor.push_back %[[A0]], %[[A8]], %[[A5]], %[[A1]] : index, memref, f64, index +// CHECK: %[[A16:.*]] = arith.index_cast %[[A15]] : index to i64 +// CHECK: %[[A17:.*]] = llvm.insertvalue %[[A16]], %[[A13]][1, 0] +// CHECK: return %[[A14]], %[[A17]] : memref, !llvm.struct<(array<3 x i64>, array<1 x i64>)> func.func @sparse_alloc_3d() -> tensor<10x20x30xf64, #Dense3D> { %0 = bufferization.alloc_tensor() : tensor<10x20x30xf64, #Dense3D> %1 = sparse_tensor.load %0 : tensor<10x20x30xf64, #Dense3D> @@ -364,13 +347,9 @@ // CHECK-LABEL: func.func @sparse_expansion3( // CHECK-SAME: %[[D0:.*]]: index, // CHECK-SAME: %{{.*}}: index) -> memref { -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[S0:.*]] = memref.alloc() : memref<2xindex> -// CHECK: memref.store %[[D0]], %[[S0]]{{\[}}%[[C1]]] : memref<2xindex> -// CHECK: %[[D1:.*]] = memref.load %[[S0]]{{\[}}%[[C1]]] : memref<2xindex> -// CHECK: %[[V:.*]] = memref.alloc(%[[D1]]) : memref -// CHECK: %[[B:.*]] = memref.alloc(%[[D1]]) : memref -// CHECK: %[[D:.*]] = memref.alloc(%[[D1]]) : memref +// CHECK: 
%[[V:.*]] = memref.alloc(%[[D0]]) : memref +// CHECK: %[[B:.*]] = memref.alloc(%[[D0]]) : memref +// CHECK: %[[D:.*]] = memref.alloc(%[[D0]]) : memref // CHECK: linalg.fill ins(%{{.*}} : f64) outs(%[[V]] : memref) // CHECK: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref) // CHECK: return %[[D]] : memref @@ -382,45 +361,39 @@ } // CHECK-LABEL: func.func private @_insert_C_100_f64_0_0( -// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: index, -// CHECK-SAME: %[[A6:.*6]]: f64) -// CHECK: %[[PV:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]] {idx = 2 : index} : memref<3xindex>, memref, f64 -// CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[PV]] +// CHECK-SAME: %[[A1:.*0]]: memref, +// CHECK-SAME: %[[A2:.*1]]: memref, +// CHECK-SAME: %[[A3:.*2]]: memref, +// CHECK-SAME: %[[A4:.*3]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>, +// CHECK-SAME: %[[A5:.*4]]: index, +// CHECK-SAME: %[[A6:.*5]]: f64) // -// CHECK-LABEL: func @sparse_compression_1d( -// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-LABEL: func.func @sparse_compression_1d( +// CHECK-SAME: %[[A0:.*0]]: memref, +// CHECK-SAME: %[[A1:.*1]]: memref, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>, // CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: memref, -// CHECK-SAME: %[[A6:.*6]]: memref, -// CHECK-SAME: %[[A7:.*7]]: memref, -// CHECK-SAME: %[[A8:.*8]]: index) -// CHECK-DAG: %[[B0:.*]] = arith.constant false -// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64 -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: sparse_tensor.sort %[[A8]], 
%[[A7]] : memref -// CHECK: %[[R:.*]]:5 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] -// CHECK-SAME: iter_args(%[[P0:.*]] = %[[A0]], %[[P1:.*]] = %[[A1]], %[[P2:.*]] = %[[A2]], %[[P3:.*]] = %[[A3]], %[[P4:.*]] = %[[A4]]) -> (memref<1xindex>, memref<3xindex>, memref, memref, memref) { -// CHECK: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref -// CHECK: %[[VAL:.*]] = memref.load %[[A5]][%[[INDEX]]] : memref -// CHECK: %[[C:.*]]:5 = func.call @_insert_C_100_f64_0_0(%[[P0]], %[[P1]], %[[P2]], %[[P3]], %[[P4]], %[[INDEX]], %[[VAL]]) -// CHECK: memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref -// CHECK: memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref -// CHECK: scf.yield %[[C]]#0, %[[C]]#1, %[[C]]#2, %[[C]]#3, %[[C]]#4 : memref<1xindex>, memref<3xindex>, memref, memref, memref +// CHECK-SAME: %[[A5:.*5]]: memref, +// CHECK-SAME: %[[A6:.*6]]: memref, +// CHECK-SAME: %[[A7:.*7]]: index) -> (memref, memref, memref, !llvm.struct<(array<1 x i64>, array<3 x i64>)>) { +// CHECK-DAG: %[[A8:.*]] = arith.constant false +// CHECK-DAG: %[[A9:.*]] = arith.constant 0.000000e+00 : f64 +// CHECK-DAG: %[[A10:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[A11:.*]] = arith.constant 0 : index +// CHECK: sparse_tensor.sort %[[A7]], %[[A6]] : memref +// CHECK: %[[A12:.*]]:4 = scf.for %[[A13:.*]] = %[[A11]] to %[[A7]] step %[[A10]] iter_args(%[[A14:.*]] = %[[A0]], %[[A15:.*]] = %[[A1]], %[[A16:.*]] = %[[A2]], %[[A17:.*]] = %[[A3]]) +// CHECK: %[[A18:.*]] = memref.load %[[A6]]{{\[}}%[[A13]]] : memref +// CHECK: %[[A19:.*]] = memref.load %[[A4]]{{\[}}%[[A18]]] : memref +// CHECK: %[[A20:.*]]:4 = func.call @_insert_C_100_f64_0_0(%[[A14]], %[[A15]], %[[A16]], %[[A17]], %[[A18]], %[[A19]]) +// CHECK: memref.store %[[A9]], %[[A4]]{{\[}}%[[A18]]] : memref +// CHECK: memref.store %[[A8]], %[[A5]]{{\[}}%[[A18]]] : memref +// CHECK: scf.yield %[[A20]]#0, %[[A20]]#1, %[[A20]]#2, %[[A20]]#3 // CHECK: } -// CHECK: memref.dealloc %[[A5]] : memref -// CHECK: memref.dealloc %[[A6]] : 
memref -// CHECK: memref.dealloc %[[A7]] : memref -// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4 -// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref +// CHECK: memref.dealloc %[[A4]] : memref +// CHECK: memref.dealloc %[[A5]] : memref +// CHECK: memref.dealloc %[[A6]] : memref +// CHECK: return %[[A21:.*]]#0, %[[A21]]#1, %[[A21]]#2, %[[A21]]#3 : memref, memref, memref, !llvm.struct<(array<1 x i64>, array<3 x i64>)> func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>, %values: memref, %filled: memref, @@ -433,47 +406,54 @@ } // CHECK-LABEL: func.func private @_insert_D_C_8_8_f64_64_32( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: index, -// CHECK-SAME: %[[A6:.*6]]: index, -// CHECK-SAME: %[[A7:.*7]]: f64) -// CHECK: %[[PV:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A7]] {idx = 2 : index} : memref<3xindex>, memref, f64 -// CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[PV]] +// CHECK-SAME: %[[A1:.*0]]: memref, +// CHECK-SAME: %[[A2:.*1]]: memref, +// CHECK-SAME: %[[A3:.*2]]: memref, +// CHECK-SAME: %[[A4:.*3]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>, +// CHECK-SAME: %[[A5:.*4]]: index, +// CHECK-SAME: %[[A6:.*5]]: index, +// CHECK-SAME: %[[A7:.*6]]: f64) // -// CHECK-LABEL: func @sparse_compression( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-LABEL: func.func @sparse_compression( +// CHECK-SAME: %[[A0:.*0]]: memref, +// CHECK-SAME: %[[A1:.*1]]: memref, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>, // CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: memref, -// CHECK-SAME: %[[A6:.*6]]: memref, -// 
CHECK-SAME: %[[A7:.*7]]: memref, -// CHECK-SAME: %[[A8:.*8]]: index, -// CHECK-SAME: %[[A9:.*9]]: index) -// CHECK-DAG: %[[B0:.*]] = arith.constant false -// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64 -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: sparse_tensor.sort %[[A8]], %[[A7]] : memref -// CHECK: %[[R:.*]]:5 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] -// CHECK-SAME: iter_args(%[[P0:.*]] = %[[A0]], %[[P1:.*]] = %[[A1]], %[[P2:.*]] = %[[A2]], %[[P3:.*]] = %[[A3]], %[[P4:.*]] = %[[A4]]) -> (memref<2xindex>, memref<3xindex>, memref, memref, memref) { -// CHECK: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref -// CHECK: %[[VAL:.*]] = memref.load %[[A5]][%[[INDEX]]] : memref -// CHECK: %[[C:.*]]:5 = func.call @_insert_D_C_8_8_f64_64_32(%[[P0]], %[[P1]], %[[P2]], %[[P3]], %[[P4]], %[[A9]], %[[INDEX]], %[[VAL]]) -// CHECK: memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref -// CHECK: memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref -// CHECK: scf.yield %[[C]]#0, %[[C]]#1, %[[C]]#2, %[[C]]#3, %[[C]]#4 : memref<2xindex>, memref<3xindex>, memref, memref, memref +// CHECK-SAME: %[[A5:.*5]]: memref, +// CHECK-SAME: %[[A6:.*6]]: memref, +// CHECK-SAME: %[[A7:.*7]]: index, +// CHECK-SAME: %[[A8:.*8]]: index) -> (memref, memref, memref, !llvm.struct<(array<2 x i64>, array<3 x i64>)>) { +// CHECK-DAG: %[[A9:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[A10:.*]] = arith.constant false +// CHECK-DAG: %[[A11:.*]] = arith.constant 0.000000e+00 : f64 +// CHECK-DAG: %[[A12:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[A13:.*]] = arith.constant 0 : index +// CHECK: sparse_tensor.sort %[[A7]], %[[A6]] : memref +// CHECK: %[[A14:.*]]:4 = scf.for %[[A15:.*]] = %[[A13]] to %[[A7]] step %[[A12]] iter_args(%[[A16:.*]] = %[[A0]], %[[A17:.*]] = %[[A1]], %[[A18:.*]] = %[[A2]], %[[A19:.*]] = %[[A3]]) +// CHECK: %[[A20:.*]] = memref.load %[[A6]]{{\[}}%[[A15]]] : memref +// CHECK: 
%[[A21:.*]] = memref.load %[[A4]]{{\[}}%[[A20]]] : memref +// CHECK: %[[A22:.*]]:4 = func.call @_insert_D_C_8_8_f64_64_32(%[[A16]], %[[A17]], %[[A18]], %[[A19]], %[[A8]], %[[A20]], %[[A21]]) +// CHECK: memref.store %[[A11]], %[[A4]]{{\[}}%[[A20]]] : memref +// CHECK: memref.store %[[A10]], %[[A5]]{{\[}}%[[A20]]] : memref +// CHECK: scf.yield %[[A22]]#0, %[[A22]]#1, %[[A22]]#2, %[[A22]]#3 // CHECK: } -// CHECK: memref.dealloc %[[A5]] : memref -// CHECK: memref.dealloc %[[A6]] : memref -// CHECK: memref.dealloc %[[A7]] : memref -// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4 -// CHECK-SAME: memref<2xindex>, memref<3xindex>, memref, memref, memref +// CHECK: memref.dealloc %[[A4]] : memref +// CHECK: memref.dealloc %[[A5]] : memref +// CHECK: memref.dealloc %[[A6]] : memref +// CHECK: %[[A23:.*]] = llvm.extractvalue %[[A24:.*]]#3[1, 0] +// CHECK: %[[A25:.*]] = arith.index_cast %[[A23]] : i64 to index +// CHECK: %[[A26:.*]] = memref.load %[[A24]]#0{{\[}}%[[A13]]] : memref +// CHECK: %[[A27:.*]] = scf.for %[[A28:.*]] = %[[A12]] to %[[A25]] step %[[A12]] iter_args(%[[A29:.*]] = %[[A26]]) -> (i32) { +// CHECK: %[[A30:.*]] = memref.load %[[A24]]#0{{\[}}%[[A28]]] : memref +// CHECK: %[[A31:.*]] = arith.cmpi eq, %[[A30]], %[[A9]] : i32 +// CHECK: %[[A32:.*]] = arith.select %[[A31]], %[[A29]], %[[A30]] : i32 +// CHECK: scf.if %[[A31]] { +// CHECK: memref.store %[[A29]], %[[A24]]#0{{\[}}%[[A28]]] : memref +// CHECK: } +// CHECK: scf.yield %[[A32]] : i32 +// CHECK: } +// CHECK: return %[[A24]]#0, %[[A24]]#1, %[[A24]]#2, %[[A24]]#3 : memref, memref, memref, !llvm.struct<(array<2 x i64>, array<3 x i64>)> func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>, %values: memref, %filled: memref, @@ -487,47 +467,52 @@ } // CHECK-LABEL: func.func private @_insert_D_C_8_8_f64_0_0( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: 
%[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: index, -// CHECK-SAME: %[[A6:.*6]]: index, -// CHECK-SAME: %[[A7:.*7]]: f64) -// CHECK: %[[PV:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A7]] {idx = 2 : index} : memref<3xindex>, memref, f64 -// CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[PV]] +// CHECK-SAME: %[[A1:.*0]]: memref, +// CHECK-SAME: %[[A2:.*1]]: memref, +// CHECK-SAME: %[[A3:.*2]]: memref, +// CHECK-SAME: %[[A4:.*3]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>, +// CHECK-SAME: %[[A5:.*4]]: index, +// CHECK-SAME: %[[A6:.*5]]: index, +// CHECK-SAME: %[[A7:.*6]]: f64) // -// CHECK-LABEL: func @sparse_compression_unordered( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-LABEL: func.func @sparse_compression_unordered( +// CHECK-SAME: %[[A0:.*0]]: memref, +// CHECK-SAME: %[[A1:.*1]]: memref, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>, // CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: memref, -// CHECK-SAME: %[[A6:.*6]]: memref, -// CHECK-SAME: %[[A7:.*7]]: memref, -// CHECK-SAME: %[[A8:.*8]]: index, -// CHECK-SAME: %[[A9:.*9]]: index) -// CHECK-DAG: %[[B0:.*]] = arith.constant false -// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64 -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-NOT: sparse_tensor.sort -// CHECK: %[[R:.*]]:5 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] -// CHECK-SAME: iter_args(%[[P0:.*]] = %[[A0]], %[[P1:.*]] = %[[A1]], %[[P2:.*]] = %[[A2]], %[[P3:.*]] = %[[A3]], %[[P4:.*]] = %[[A4]]) -> (memref<2xindex>, memref<3xindex>, memref, memref, memref) { -// CHECK: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref -// CHECK: %[[VAL:.*]] = memref.load %[[A5]][%[[INDEX]]] : memref -// CHECK: %[[C:.*]]:5 = func.call 
@_insert_D_C_8_8_f64_0_0(%[[P0]], %[[P1]], %[[P2]], %[[P3]], %[[P4]], %[[A9]], %[[INDEX]], %[[VAL]]) -// CHECK: memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref -// CHECK: memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref -// CHECK: scf.yield %[[C]]#0, %[[C]]#1, %[[C]]#2, %[[C]]#3, %[[C]]#4 : memref<2xindex>, memref<3xindex>, memref, memref, memref +// CHECK-SAME: %[[A5:.*5]]: memref, +// CHECK-SAME: %[[A6:.*6]]: memref, +// CHECK-SAME: %[[A7:.*7]]: index, +// CHECK-SAME: %[[A8:.*8]]: index) +// CHECK-DAG: %[[A9:.*]] = arith.constant false +// CHECK-DAG: %[[A10:.*]] = arith.constant 0.000000e+00 : f64 +// CHECK-DAG: %[[A11:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[A12:.*]] = arith.constant 1 : index +// CHECK: %[[A13:.*]]:4 = scf.for %[[A14:.*]] = %[[A11]] to %[[A7]] step %[[A12]] iter_args(%[[A15:.*]] = %[[A0]], %[[A16:.*]] = %[[A1]], %[[A17:.*]] = %[[A2]], %[[A18:.*]] = %[[A3]]) -> (memref, memref, memref, !llvm.struct<(array<2 x i64>, array<3 x i64>)>) { +// CHECK: %[[A19:.*]] = memref.load %[[A6]]{{\[}}%[[A14]]] : memref +// CHECK: %[[A20:.*]] = memref.load %[[A4]]{{\[}}%[[A19]]] : memref +// CHECK: %[[A21:.*]]:4 = func.call @_insert_D_C_8_8_f64_0_0(%[[A15]], %[[A16]], %[[A17]], %[[A18]], %[[A8]], %[[A19]], %[[A20]]) +// CHECK: memref.store %[[A10]], %[[A4]]{{\[}}%[[A19]]] : memref +// CHECK: memref.store %[[A9]], %[[A5]]{{\[}}%[[A19]]] : memref +// CHECK: scf.yield %[[A21]]#0, %[[A21]]#1, %[[A21]]#2, %[[A21]]#3 +// CHECK: } +// CHECK: memref.dealloc %[[A4]] : memref +// CHECK: memref.dealloc %[[A5]] : memref +// CHECK: memref.dealloc %[[A6]] : memref +// CHECK: %[[A22:.*]] = llvm.extractvalue %[[A23:.*]]#3[1, 0] +// CHECK: %[[A24:.*]] = arith.index_cast %[[A22]] : i64 to index +// CHECK: %[[A25:.*]] = memref.load %[[A23]]#0{{\[}}%[[A11]]] : memref +// CHECK: %[[A26:.*]] = scf.for %[[A27:.*]] = %[[A12]] to %[[A24]] step %[[A12]] iter_args(%[[A28:.*]] = %[[A25]]) -> (index) { +// CHECK: %[[A29:.*]] = memref.load %[[A23]]#0{{\[}}%[[A27]]] : memref +// 
CHECK: %[[A30:.*]] = arith.cmpi eq, %[[A29]], %[[A11]] : index +// CHECK: %[[A31:.*]] = arith.select %[[A30]], %[[A28]], %[[A29]] : index +// CHECK: scf.if %[[A30]] { +// CHECK: memref.store %[[A28]], %[[A23]]#0{{\[}}%[[A27]]] : memref +// CHECK: } +// CHECK: scf.yield %[[A31]] : index // CHECK: } -// CHECK: memref.dealloc %[[A5]] : memref -// CHECK: memref.dealloc %[[A6]] : memref -// CHECK: memref.dealloc %[[A7]] : memref -// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4 -// CHECK-SAME: memref<2xindex>, memref<3xindex>, memref, memref, memref +// CHECK: return %[[A23]]#0, %[[A23]]#1, %[[A23]]#2, %[[A23]]#3 : memref, memref, memref, !llvm.struct<(array<2 x i64>, array<3 x i64>)> func.func @sparse_compression_unordered(%tensor: tensor<8x8xf64, #UCSR>, %values: memref, %filled: memref, @@ -541,26 +526,22 @@ } // CHECK-LABEL: func.func private @_insert_C_128_f64_0_0( -// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: index, -// CHECK-SAME: %[[A6:.*6]]: f64) -// CHECK: %[[P:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]] {idx = 2 : index} : memref<3xindex>, memref, f64 -// CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[P]] : -// CHECK: func @sparse_insert( -// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: index, -// CHECK-SAME: %[[A6:.*6]]: f64) -// CHECK: %[[R:.*]]:5 = call @_insert_C_128_f64_0_0(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]]) -// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4 -// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref +// CHECK-SAME: %[[A1:.*0]]: memref, +// CHECK-SAME: %[[A2:.*1]]: memref, +// 
CHECK-SAME: %[[A3:.*2]]: memref, +// CHECK-SAME: %[[A4:.*3]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>, +// CHECK-SAME: %[[A5:.*4]]: index, +// CHECK-SAME: %[[A6:.*5]]: f64) +// +// CHECK-LABEL: func @sparse_insert( +// CHECK-SAME: %[[A1:.*0]]: memref, +// CHECK-SAME: %[[A2:.*1]]: memref, +// CHECK-SAME: %[[A3:.*2]]: memref, +// CHECK-SAME: %[[A4:.*3]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>, +// CHECK-SAME: %[[A5:.*4]]: index, +// CHECK-SAME: %[[A6:.*5]]: f64) +// CHECK: %[[R:.*]]:4 = call @_insert_C_128_f64_0_0(%[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]]) +// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3 func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SV> { %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SV> %1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SV> @@ -568,26 +549,22 @@ } // CHECK-LABEL: func.func private @_insert_C_128_f64_64_32( -// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: index, -// CHECK-SAME: %[[A6:.*6]]: f64) -// CHECK: %[[P:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]] {idx = 2 : index} : memref<3xindex>, memref, f64 -// CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[P]] : -// CHECK: func @sparse_insert_typed( -// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: index, -// CHECK-SAME: %[[A6:.*6]]: f64) -// CHECK: %[[R:.*]]:5 = call @_insert_C_128_f64_64_32(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]]) -// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4 -// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref +// 
CHECK-SAME: %[[A1:.*]]: memref, +// CHECK-SAME: %[[A2:.*]]: memref, +// CHECK-SAME: %[[A3:.*]]: memref, +// CHECK-SAME: %[[A4:.*]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>, +// CHECK-SAME: %[[A5:.*]]: index, +// CHECK-SAME: %[[A6:.*]]: f64) +// +// CHECK-LABEL: func @sparse_insert_typed( +// CHECK-SAME: %[[A1:.*]]: memref, +// CHECK-SAME: %[[A2:.*]]: memref, +// CHECK-SAME: %[[A3:.*]]: memref, +// CHECK-SAME: %[[A4:.*]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>, +// CHECK-SAME: %[[A5:.*]]: index, +// CHECK-SAME: %[[A6:.*]]: f64) +// CHECK: %[[R:.*]]:4 = call @_insert_C_128_f64_64_32(%[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]]) +// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3 func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SparseVector> { %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector> %1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SparseVector> @@ -595,14 +572,13 @@ } // CHECK-LABEL: func.func @sparse_nop_convert( -// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref) -// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]] : -// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref +// CHECK-SAME: %[[A1:.*]]: memref, +// CHECK-SAME: %[[A2:.*]]: memref, +// CHECK-SAME: %[[A3:.*]]: memref, +// CHECK-SAME: %[[A4:.*]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>) +// CHECK: return %[[A1]], %[[A2]], %[[A3]], %[[A4]] : +// CHECK-SAME: memref, memref, memref, !llvm.struct<(array<1 x i64>, array<3 x i64>)> func.func @sparse_nop_convert(%arg0: tensor<32xf32, #SparseVector>) -> tensor { %0 = sparse_tensor.convert %arg0 : tensor<32xf32, #SparseVector> to tensor return %0 : tensor -} \ No newline at end of file +} diff --git 
a/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir b/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir --- a/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir +++ b/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir @@ -2,28 +2,35 @@ #SV = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }> -// CHECK-LABEL: func @sparse_alloc_sparse_vector( -// CHECK-SAME: %[[A:.*]]: index) -> -// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64 -// CHECK: %[[T0:.*]] = memref.alloc() : memref<1xindex> -// CHECK: %[[T1:.*]] = memref.alloc() : memref<3xindex> -// CHECK: %[[T2:.*]] = memref.alloc() : memref<16xindex> -// CHECK: %[[T3:.*]] = memref.cast %[[T2]] : memref<16xindex> to memref -// CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T2]] : memref<16xindex>) -// CHECK: %[[T4:.*]] = memref.alloc() : memref<16xindex> -// CHECK: %[[T5:.*]] = memref.cast %[[T4]] : memref<16xindex> to memref -// CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T4]] : memref<16xindex>) -// CHECK: %[[T6:.*]] = memref.alloc() : memref<16xf64> -// CHECK: %[[T7:.*]] = memref.cast %[[T6]] : memref<16xf64> to memref -// CHECK: linalg.fill ins(%[[F0]] : f64) outs(%[[T6]] : memref<16xf64>) -// CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T1]] : memref<3xindex>) -// CHECK: memref.store %[[A]], %[[T0]][%[[C0]]] : memref<1xindex> -// CHECK: %[[P0:.*]] = sparse_tensor.push_back %[[T1]], %[[T3]] -// CHECK: %[[P1:.*]] = sparse_tensor.push_back %[[T1]], %[[P0]] -// CHECK: return %[[T0]], %[[T1]], %[[P1]], %[[T5]], %[[T7]] : +// CHECK-LABEL: func.func @sparse_alloc_sparse_vector( +// CHECK-SAME: %[[VAL_0:.*]]: index) -> +// CHECK-SAME: (memref, memref, memref, !llvm.struct<(array<1 x i64>, array<3 x i64>)>) { +// CHECK: %[[C0:.*]] = arith.constant 0 : index 
+// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[I0:.*]] = arith.constant 0 : i64 +// CHECK: %[[F0:.*]] = arith.constant 0.000000e+00 : f64 +// CHECK: %[[VAL_5:.*]] = memref.alloc() : memref<16xindex> +// CHECK: %[[VAL_6:.*]] = memref.cast %[[VAL_5]] : memref<16xindex> to memref +// CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[VAL_5]] : memref<16xindex>) +// CHECK: %[[VAL_7:.*]] = memref.alloc() : memref<16xindex> +// CHECK: %[[VAL_8:.*]] = memref.cast %[[VAL_7]] : memref<16xindex> to memref +// CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[VAL_7]] : memref<16xindex>) +// CHECK: %[[VAL_9:.*]] = memref.alloc() : memref<16xf64> +// CHECK: %[[VAL_10:.*]] = memref.cast %[[VAL_9]] : memref<16xf64> to memref +// CHECK: linalg.fill ins(%[[F0]] : f64) outs(%[[VAL_9]] : memref<16xf64>) +// CHECK: %[[MD1:.*]] = llvm.mlir.undef : !llvm.struct<(array<1 x i64>, array<3 x i64>)> +// CHECK: %[[MD2:.*]] = llvm.insertvalue %[[I0]], %[[MD1]][1, 0] +// CHECK: %[[MD3:.*]] = llvm.insertvalue %[[I0]], %[[MD2]][1, 1] +// CHECK: %[[MD4:.*]] = llvm.insertvalue %[[I0]], %[[MD3]][1, 2] +// CHECK: %[[VAL_15:.*]] = arith.index_cast %[[VAL_0]] : index to i64 +// CHECK: %[[MD5:.*]] = llvm.insertvalue %[[VAL_15]], %[[MD4]][0, 0] +// CHECK: %[[VAL_17:.*]], %[[VAL_18:.*]] = sparse_tensor.push_back %[[C0]], %[[VAL_6]], %[[C0]] +// CHECK: %[[VAL_19:.*]] = arith.index_cast %[[VAL_18]] : index to i64 +// CHECK: %[[MD6:.*]] = llvm.insertvalue %[[VAL_19]], %[[MD5]][1, 0] +// CHECK: %[[VAL_21:.*]], %[[VAL_22:.*]] = sparse_tensor.push_back %[[VAL_18]], %[[VAL_17]], %[[C0]], %[[C1]] +// CHECK: %[[VAL_23:.*]] = arith.index_cast %[[VAL_22]] : index to i64 +// CHECK: %[[MD:.*]] = llvm.insertvalue %[[VAL_23]], %[[MD6]][1, 0] +// CHECK: return %[[VAL_21]], %[[VAL_8]], %[[VAL_10]], %[[MD]] func.func @sparse_alloc_sparse_vector(%arg0: index) -> tensor { %0 = bufferization.alloc_tensor(%arg0) : tensor %1 = sparse_tensor.load %0 : tensor diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir 
b/mlir/test/Dialect/SparseTensor/invalid.mlir --- a/mlir/test/Dialect/SparseTensor/invalid.mlir +++ b/mlir/test/Dialect/SparseTensor/invalid.mlir @@ -192,19 +192,19 @@ // ----- -func.func @sparse_push_back(%arg0: memref, %arg1: memref, %arg2: f32) -> memref { +func.func @sparse_push_back(%arg0: index, %arg1: memref, %arg2: f32) -> (memref, index) { // expected-error@+1 {{'sparse_tensor.push_back' op failed to verify that value type matches element type of inBuffer}} - %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref, memref, f32 - return %0 : memref + %0:2 = sparse_tensor.push_back %arg0, %arg1, %arg2 : index, memref, f32 + return %0#0, %0#1 : memref, index } // ----- -func.func @sparse_push_back_n(%arg0: memref, %arg1: memref, %arg2: f32) -> memref { +func.func @sparse_push_back_n(%arg0: index, %arg1: memref, %arg2: f32) -> (memref, index) { %c0 = arith.constant 0: index // expected-error@+1 {{'sparse_tensor.push_back' op n must be not less than 1}} - %0 = sparse_tensor.push_back %arg0, %arg1, %arg2, %c0 {idx = 2 : index} : memref, memref, f32, index - return %0 : memref + %0:2 = sparse_tensor.push_back %arg0, %arg1, %arg2, %c0 : index, memref, f32, index + return %0#0, %0#1 : memref, index } // ----- diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir --- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir +++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir @@ -201,41 +201,41 @@ // ----- // CHECK-LABEL: func @sparse_push_back( -// CHECK-SAME: %[[A:.*]]: memref, +// CHECK-SAME: %[[A:.*]]: index, // CHECK-SAME: %[[B:.*]]: memref, -// CHECK-SAME: %[[C:.*]]: f64) -> memref { -// CHECK: %[[D:.*]] = sparse_tensor.push_back %[[A]], %[[B]], %[[C]] {idx = 2 : index} : memref, memref, f64 +// CHECK-SAME: %[[C:.*]]: f64) -> (memref, index) { +// CHECK: %[[D:.*]] = sparse_tensor.push_back %[[A]], %[[B]], %[[C]] : index, memref, f64 // CHECK: return %[[D]] -func.func @sparse_push_back(%arg0: memref, 
%arg1: memref, %arg2: f64) -> memref { - %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref, memref, f64 - return %0 : memref +func.func @sparse_push_back(%arg0: index, %arg1: memref, %arg2: f64) -> (memref, index) { + %0:2 = sparse_tensor.push_back %arg0, %arg1, %arg2 : index, memref, f64 + return %0#0, %0#1 : memref, index } // ----- // CHECK-LABEL: func @sparse_push_back_inbound( -// CHECK-SAME: %[[A:.*]]: memref, +// CHECK-SAME: %[[A:.*]]: index, // CHECK-SAME: %[[B:.*]]: memref, -// CHECK-SAME: %[[C:.*]]: f64) -> memref { -// CHECK: %[[D:.*]] = sparse_tensor.push_back inbounds %[[A]], %[[B]], %[[C]] {idx = 2 : index} : memref, memref, f64 +// CHECK-SAME: %[[C:.*]]: f64) -> (memref, index) { +// CHECK: %[[D:.*]] = sparse_tensor.push_back inbounds %[[A]], %[[B]], %[[C]] : index, memref, f64 // CHECK: return %[[D]] -func.func @sparse_push_back_inbound(%arg0: memref, %arg1: memref, %arg2: f64) -> memref { - %0 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 {idx = 2 : index} : memref, memref, f64 - return %0 : memref +func.func @sparse_push_back_inbound(%arg0: index, %arg1: memref, %arg2: f64) -> (memref, index) { + %0:2 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 : index, memref, f64 + return %0#0, %0#1 : memref, index } // ----- // CHECK-LABEL: func @sparse_push_back_n( -// CHECK-SAME: %[[A:.*]]: memref, +// CHECK-SAME: %[[A:.*]]: index, // CHECK-SAME: %[[B:.*]]: memref, // CHECK-SAME: %[[C:.*]]: f64, -// CHECK-SAME: %[[D:.*]]: index) -> memref { -// CHECK: %[[E:.*]] = sparse_tensor.push_back %[[A]], %[[B]], %[[C]], %[[D]] {idx = 2 : index} : memref, memref, f64, index +// CHECK-SAME: %[[D:.*]]: index) -> (memref, index) { +// CHECK: %[[E:.*]] = sparse_tensor.push_back %[[A]], %[[B]], %[[C]], %[[D]] : index, memref, f64, index // CHECK: return %[[E]] -func.func @sparse_push_back_n(%arg0: memref, %arg1: memref, %arg2: f64, %arg3: index) -> memref { - %0 = sparse_tensor.push_back %arg0, %arg1, %arg2, %arg3 {idx = 2 : 
index} : memref, memref, f64, index - return %0 : memref +func.func @sparse_push_back_n(%arg0: index, %arg1: memref, %arg2: f64, %arg3: index) -> (memref, index) { + %0:2 = sparse_tensor.push_back %arg0, %arg1, %arg2, %arg3 : index, memref, f64, index + return %0#0, %0#1 : memref, index } // ----- diff --git a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir --- a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir +++ b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir @@ -1,24 +1,23 @@ // RUN: mlir-opt %s -sparse-tensor-codegen -cse | FileCheck %s #SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }> -// CHECK-LABEL: func @for( -// CHECK-SAME: %[[DIM_SIZE:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[MEM_SIZE:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[POINTER:.*2]]: memref, -// CHECK-SAME: %[[INDICES:.*3]]: memref, -// CHECK-SAME: %[[VALUE:.*4]]: memref, -// CHECK-SAME: %[[LB:.*5]]: index, -// CHECK-SAME: %[[UB:.*6]]: index, -// CHECK-SAME: %[[STEP:.*7]]: index) -// CHECK: %[[OUT:.*]]:5 = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args( -// CHECK-SAME: %[[SIZE:.*]] = %[[DIM_SIZE]], -// CHECK-SAME: %[[MEM:.*]] = %[[MEM_SIZE]], -// CHECK-SAME: %[[PTR:.*]] = %[[POINTER]], -// CHECK-SAME: %[[IDX:.*]] = %[[INDICES]], -// CHECK-SAME: %[[VAL:.*]] = %[[VALUE]]) -// CHECK: scf.yield %[[SIZE]], %[[MEM]], %[[PTR]], %[[IDX]], %[[VAL]] : memref<1xindex>, memref<3xindex>, memref, memref, memref -// CHECK: } -// CHECK: return %[[OUT]]#0, %[[OUT]]#1, %[[OUT]]#2, %[[OUT]]#3, %[[OUT]]#4 : memref<1xindex>, memref<3xindex>, memref, memref, memref + +// CHECK-LABEL: func.func @for( +// CHECK-SAME: %[[VAL_1:.*0]]: memref, +// CHECK-SAME: %[[VAL_2:.*1]]: memref, +// CHECK-SAME: %[[VAL_3:.*2]]: memref, +// CHECK-SAME: %[[VAL_4:.*3]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>, +// CHECK-SAME: %[[VAL_5:.*4]]: index, +// CHECK-SAME: %[[VAL_6:.*5]]: index, +// CHECK-SAME: 
%[[VAL_7:.*6]]: index) -> (memref, memref, memref, !llvm.struct<(array<1 x i64>, array<3 x i64>)>) { +// CHECK: %[[VAL_8:.*]]:4 = scf.for %[[VAL_9:.*]] = %[[VAL_5]] to %[[VAL_6]] step %[[VAL_7]] iter_args( +// CHECK-SAME: %[[VAL_11:.*]] = %[[VAL_1]], +// CHECK-SAME: %[[VAL_12:.*]] = %[[VAL_2]], +// CHECK-SAME: %[[VAL_13:.*]] = %[[VAL_3]], +// CHECK-SAME: %[[VAL_14:.*]] = %[[VAL_4]]) +// CHECK: scf.yield %[[VAL_11]], %[[VAL_12]], %[[VAL_13]], %[[VAL_14]] : +// CHECK: } +// CHECK: return %[[VAL_8]]#0, %[[VAL_8]]#1, %[[VAL_8]]#2, %[[VAL_8]]#3 func.func @for(%in: tensor<1024xf32, #SparseVector>, %lb: index, %ub: index, %step: index) -> tensor<1024xf32, #SparseVector> { %1 = scf.for %i = %lb to %ub step %step iter_args(%vin = %in) @@ -28,26 +27,23 @@ return %1 : tensor<1024xf32, #SparseVector> } - -// CHECK-LABEL: func @if( -// CHECK-SAME: %[[DIM_SIZE:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[MEM_SIZE:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[POINTER:.*2]]: memref, -// CHECK-SAME: %[[INDICES:.*3]]: memref, -// CHECK-SAME: %[[VALUE:.*4]]: memref, -// CHECK-SAME: %[[DIM_SIZE_1:.*5]]: memref<1xindex>, -// CHECK-SAME: %[[MEM_SIZE_1:.*6]]: memref<3xindex>, -// CHECK-SAME: %[[POINTER_1:.*7]]: memref, -// CHECK-SAME: %[[INDICES_1:.*8]]: memref, -// CHECK-SAME: %[[VALUE_1:.*9]]: memref, -// CHECK-SAME: %[[I1:.*10]]: i1) -> -// CHECK-SAME: (memref<1xindex>, memref<3xindex>, memref, memref, memref) { -// CHECK: %[[SV:.*]]:5 = scf.if %[[I1]] -> (memref<1xindex>, memref<3xindex>, memref, memref, memref) { -// CHECK: scf.yield %[[DIM_SIZE]], %[[MEM_SIZE]], %[[POINTER]], %[[INDICES]], %[[VALUE]] : memref<1xindex>, memref<3xindex>, memref, memref, memref -// CHECK: } else { -// CHECK: scf.yield %[[DIM_SIZE_1]], %[[MEM_SIZE_1]], %[[POINTER_1]], %[[INDICES_1]], %[[VALUE_1]] : memref<1xindex>, memref<3xindex>, memref, memref, memref -// CHECK: } -// CHECK: return %[[SV]]#0, %[[SV]]#1, %[[SV]]#2, %[[SV]]#3, %[[SV]]#4 : memref<1xindex>, memref<3xindex>, memref, memref, memref +// 
CHECK-LABEL: func.func @if( +// CHECK-SAME: %[[VAL_1:.*0]]: memref, +// CHECK-SAME: %[[VAL_2:.*1]]: memref, +// CHECK-SAME: %[[VAL_3:.*2]]: memref, +// CHECK-SAME: %[[VAL_4:.*3]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>, +// CHECK-SAME: %[[VAL_6:.*4]]: memref, +// CHECK-SAME: %[[VAL_7:.*5]]: memref, +// CHECK-SAME: %[[VAL_8:.*6]]: memref, +// CHECK-SAME: %[[VAL_9:.*7]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>, +// CHECK-SAME: %[[VAL_10:.*]]: i1) +// CHECK: %[[VAL_11:.*]]:4 = scf.if %[[VAL_10]] +// CHECK: scf.yield %[[VAL_1]], %[[VAL_2]], %[[VAL_3]], %[[VAL_4]] +// CHECK: } else { +// CHECK: scf.yield %[[VAL_6]], %[[VAL_7]], %[[VAL_8]], %[[VAL_9]] +// CHECK: } +// CHECK: return %[[VAL_11]]#0, %[[VAL_11]]#1, %[[VAL_11]]#2, %[[VAL_11]]#3 : +// CHECK-SAME: memref, memref, memref, !llvm.struct<(array<1 x i64>, array<3 x i64>)> func.func @if(%t: tensor<1024xf32, #SparseVector>, %f: tensor<1024xf32, #SparseVector>, %c: i1) -> tensor<1024xf32, #SparseVector> { @@ -59,26 +55,28 @@ return %1 : tensor<1024xf32, #SparseVector> } -// CHECK-LABEL: func @while( -// CHECK-SAME: %[[DIM_SIZE:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[MEM_SIZE:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[POINTER:.*2]]: memref, -// CHECK-SAME: %[[INDICES:.*3]]: memref, -// CHECK-SAME: %[[VALUE:.*4]]: memref, -// CHECK-SAME: %[[I1:.*5]]: i1) -> -// CHECK-SAME: (memref<1xindex>, memref<3xindex>, memref, memref, memref) { -// CHECK: %[[SV:.*]]:5 = scf.while ( -// CHECK-SAME: %[[TMP_DIM:.*]] = %[[DIM_SIZE]], -// CHECK-SAME: %[[TMP_MEM:.*]] = %[[MEM_SIZE]], -// CHECK-SAME: %[[TMP_PTR:.*]] = %[[POINTER]], -// CHECK-SAME: %[[TMP_IND:.*]] = %[[INDICES]], -// CHECK-SAME: %[[TMP_VAL:.*]] = %[[VALUE]]) -// CHECK: scf.condition(%[[I1]]) %[[TMP_DIM]], %[[TMP_MEM]], %[[TMP_PTR]], %[[TMP_IND]], %[[TMP_VAL]] : memref<1xindex>, memref<3xindex>, memref, memref, memref -// CHECK: } do { -// CHECK: ^bb0(%[[TMP_DIM]]: memref<1xindex>, %[[TMP_MEM]]: memref<3xindex>, %[[TMP_PTR]]: memref, %[[TMP_IND]]: memref, 
%[[TMP_VAL]]: memref): -// CHECK: scf.yield %[[TMP_DIM]], %[[TMP_MEM]], %[[TMP_PTR]], %[[TMP_IND]], %[[TMP_VAL]] : memref<1xindex>, memref<3xindex>, memref, memref, memref -// CHECK: } -// CHECK: return %[[SV]]#0, %[[SV]]#1, %[[SV]]#2, %[[SV]]#3, %[[SV]]#4 : memref<1xindex>, memref<3xindex>, memref, memref, memref + +// CHECK-LABEL: func.func @while( +// CHECK-SAME: %[[VAL_1:.*0]]: memref, +// CHECK-SAME: %[[VAL_2:.*1]]: memref, +// CHECK-SAME: %[[VAL_3:.*2]]: memref, +// CHECK-SAME: %[[VAL_4:.*3]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>, +// CHECK-SAME: %[[VAL_5:.*4]]: i1) +// CHECK: %[[VAL_6:.*]]:4 = scf.while ( +// CHECK-SAME: %[[VAL_8:.*]] = %[[VAL_1]], +// CHECK-SAME: %[[VAL_9:.*]] = %[[VAL_2]], +// CHECK-SAME: %[[VAL_10:.*]] = %[[VAL_3]], +// CHECK-SAME: %[[VAL_11:.*]] = %[[VAL_4]]) +// CHECK: scf.condition(%[[VAL_5]]) %[[VAL_8]], %[[VAL_9]], %[[VAL_10]], %[[VAL_11]] +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_13:.*5]]: memref, +// CHECK-SAME: %[[VAL_14:.*6]]: memref, +// CHECK-SAME: %[[VAL_15:.*7]]: memref, +// CHECK-SAME: %[[VAL_16:.*8]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>): +// CHECK: scf.yield %[[VAL_13]], %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] +// CHECK: } +// CHECK: return %[[VAL_6]]#0, %[[VAL_6]]#1, %[[VAL_6]]#2, %[[VAL_6]]#3 : +// CHECK-SAME: memref, memref, memref, !llvm.struct<(array<1 x i64>, array<3 x i64>)> func.func @while(%arg0: tensor<1024xf32, #SparseVector>, %c: i1) -> tensor<1024xf32, #SparseVector> { %0 = scf.while (%in = %arg0) : (tensor<1024xf32, #SparseVector>) -> tensor<1024xf32, #SparseVector> { scf.condition(%c) %in : tensor<1024xf32, #SparseVector> diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir @@ -13,133 +13,147 @@ // Computes C = A x B with all matrices sparse (SpMSpM) in CSR. 
// // CHECK-LABEL: func.func private @_insert_D_C_4_4_f64_0_0( -// CHECK-SAME: %[[VAL_0:.*]]: memref<2xindex>, -// CHECK-SAME: %[[VAL_1:.*]]: memref<3xindex>, -// CHECK-SAME: %[[VAL_2:[^ ]+]]: memref, -// CHECK-SAME: %[[VAL_3:.*]]: memref, -// CHECK-SAME: %[[VAL_4:.*]]: memref, -// CHECK-SAME: %[[VAL_5:[^ ]+]]: index, -// CHECK-SAME: %[[VAL_6:.*]]: index, -// CHECK-SAME: %[[VAL_7:.*]]: f64) -> (memref<2xindex>, memref<3xindex>, memref, memref, memref) { -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant false -// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 1 : index -// CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_5]], %[[VAL_9]] : index -// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_5]]] : memref -// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_10]]] : memref -// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_9]]] : memref<3xindex> -// CHECK: %[[VAL_14:.*]] = arith.subi %[[VAL_12]], %[[VAL_9]] : index -// CHECK: %[[VAL_15:.*]] = arith.cmpi ult, %[[VAL_11]], %[[VAL_12]] : index +// CHECK-SAME: %[[VAL_0:.*0]]: memref, +// CHECK-SAME: %[[VAL_1:.*1]]: memref, +// CHECK-SAME: %[[VAL_2:.*2]]: memref, +// CHECK-SAME: %[[VAL_3:.*3]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>, +// CHECK-SAME: %[[VAL_4:.*4]]: index, +// CHECK-SAME: %[[VAL_5:.*5]]: index, +// CHECK-SAME: %[[VAL_6:.*6]]: f64) +// CHECK: %[[VAL_7:.*]] = arith.constant false +// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_4]], %[[VAL_8]] : index +// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_4]]] : memref +// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_9]]] : memref +// CHECK: %[[VAL_12:.*]] = llvm.extractvalue %[[VAL_3]][1, 1] : !llvm.struct<(array<2 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_13:.*]] = arith.index_cast %[[VAL_12]] : i64 to index +// CHECK: %[[VAL_14:.*]] = arith.subi %[[VAL_11]], %[[VAL_8]] : index +// CHECK: %[[VAL_15:.*]] = arith.cmpi ult, %[[VAL_10]], %[[VAL_11]] : index // 
CHECK: %[[VAL_16:.*]] = scf.if %[[VAL_15]] -> (i1) { -// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_14]]] : memref -// CHECK: %[[VAL_18:.*]] = arith.cmpi eq, %[[VAL_17]], %[[VAL_6]] : index +// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_14]]] : memref +// CHECK: %[[VAL_18:.*]] = arith.cmpi eq, %[[VAL_17]], %[[VAL_5]] : index // CHECK: scf.yield %[[VAL_18]] : i1 // CHECK: } else { -// CHECK: memref.store %[[VAL_13]], %[[VAL_2]]{{\[}}%[[VAL_5]]] : memref -// CHECK: scf.yield %[[VAL_8]] : i1 +// CHECK: memref.store %[[VAL_13]], %[[VAL_0]]{{\[}}%[[VAL_4]]] : memref +// CHECK: scf.yield %[[VAL_7]] : i1 // CHECK: } -// CHECK: %[[VAL_19:.*]] = scf.if %[[VAL_20:.*]] -> (memref) { -// CHECK: scf.yield %[[VAL_3]] : memref +// CHECK: %[[VAL_19:.*]]:2 = scf.if %[[VAL_20:.*]] -> (memref, !llvm.struct<(array<2 x i64>, array<3 x i64>)>) { +// CHECK: scf.yield %[[VAL_1]], %[[VAL_3]] : memref, !llvm.struct<(array<2 x i64>, array<3 x i64>)> // CHECK: } else { -// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_13]], %[[VAL_9]] : index -// CHECK: memref.store %[[VAL_21]], %[[VAL_2]]{{\[}}%[[VAL_10]]] : memref -// CHECK: %[[VAL_22:.*]] = sparse_tensor.push_back %[[VAL_1]], %[[VAL_3]], %[[VAL_6]] {idx = 1 : index} : memref<3xindex>, memref, index -// CHECK: scf.yield %[[VAL_22]] : memref +// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_13]], %[[VAL_8]] : index +// CHECK: memref.store %[[VAL_21]], %[[VAL_0]]{{\[}}%[[VAL_9]]] : memref +// CHECK: %[[VAL_22:.*]], %[[VAL_23:.*]] = sparse_tensor.push_back %[[VAL_13]], %[[VAL_1]], %[[VAL_5]] : index, memref, index +// CHECK: %[[VAL_24:.*]] = arith.index_cast %[[VAL_23]] : index to i64 +// CHECK: %[[VAL_25:.*]] = llvm.insertvalue %[[VAL_24]], %[[VAL_3]][1, 1] : !llvm.struct<(array<2 x i64>, array<3 x i64>)> +// CHECK: scf.yield %[[VAL_22]], %[[VAL_25]] : memref, !llvm.struct<(array<2 x i64>, array<3 x i64>)> // CHECK: } -// CHECK: %[[VAL_23:.*]] = sparse_tensor.push_back %[[VAL_1]], %[[VAL_4]], %[[VAL_7]] {idx = 2 : 
index} : memref<3xindex>, memref, f64 -// CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_24:.*]], %[[VAL_23]] : memref<2xindex>, memref<3xindex>, memref, memref, memref +// CHECK: %[[VAL_26:.*]] = llvm.extractvalue %[[VAL_27:.*]]#1[1, 2] : !llvm.struct<(array<2 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_28:.*]] = arith.index_cast %[[VAL_26]] : i64 to index +// CHECK: %[[VAL_29:.*]], %[[VAL_30:.*]] = sparse_tensor.push_back %[[VAL_28]], %[[VAL_2]], %[[VAL_6]] : index, memref, f64 +// CHECK: %[[VAL_31:.*]] = arith.index_cast %[[VAL_30]] : index to i64 +// CHECK: %[[VAL_32:.*]] = llvm.insertvalue %[[VAL_31]], %[[VAL_27]]#1[1, 2] : !llvm.struct<(array<2 x i64>, array<3 x i64>)> +// CHECK: return %[[VAL_0]], %[[VAL_27]]#0, %[[VAL_29]], %[[VAL_32]] : memref, memref, memref, !llvm.struct<(array<2 x i64>, array<3 x i64>)> // CHECK: } // CHECK-LABEL: func.func @matmul( -// CHECK-SAME: %[[VAL_0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[VAL_1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[VAL_2:.*2]]: memref, -// CHECK-SAME: %[[VAL_3:.*3]]: memref, -// CHECK-SAME: %[[VAL_4:.*4]]: memref, -// CHECK-SAME: %[[VAL_5:.*5]]: memref<2xindex>, -// CHECK-SAME: %[[VAL_6:.*6]]: memref<3xindex>, -// CHECK-SAME: %[[VAL_7:.*7]]: memref, -// CHECK-SAME: %[[VAL_8:.*8]]: memref, -// CHECK-SAME: %[[VAL_9:.*9]]: memref) -> (memref<2xindex>, memref<3xindex>, memref, memref, memref) { -// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 0.000000e+00 : f64 -// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[VAL_14:.*]] = arith.constant false -// CHECK-DAG: %[[VAL_15:.*]] = arith.constant true -// CHECK: %[[VAL_16:.*]] = memref.alloc() : memref<2xindex> -// CHECK: %[[VAL_17:.*]] = memref.alloc() : memref<3xindex> +// CHECK-SAME: %[[VAL_0:.*0]]: memref, +// CHECK-SAME: %[[VAL_1:.*1]]: memref, +// CHECK-SAME: %[[VAL_2:.*2]]: memref, +// CHECK-SAME: %[[VAL_3:.*3]]: 
!llvm.struct<(array<2 x i64>, array<3 x i64>)>, +// CHECK-SAME: %[[VAL_4:.*4]]: memref, +// CHECK-SAME: %[[VAL_5:.*5]]: memref, +// CHECK-SAME: %[[VAL_6:.*6]]: memref, +// CHECK-SAME: %[[VAL_7:.*7]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>) +// CHECK: %[[VAL_8:.*]] = arith.constant 4 : index +// CHECK: %[[VAL_9:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_10:.*]] = arith.constant 4 : i64 +// CHECK: %[[VAL_11:.*]] = arith.constant 0.000000e+00 : f64 +// CHECK: %[[VAL_12:.*]] = arith.constant 0 : i64 +// CHECK: %[[VAL_13:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_14:.*]] = arith.constant false +// CHECK: %[[VAL_15:.*]] = arith.constant true +// CHECK: %[[VAL_16:.*]] = memref.alloc() : memref<16xindex> +// CHECK: %[[VAL_17:.*]] = memref.cast %[[VAL_16]] : memref<16xindex> to memref // CHECK: %[[VAL_18:.*]] = memref.alloc() : memref<16xindex> // CHECK: %[[VAL_19:.*]] = memref.cast %[[VAL_18]] : memref<16xindex> to memref -// CHECK: %[[VAL_20:.*]] = memref.alloc() : memref<16xindex> -// CHECK: %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<16xindex> to memref -// CHECK: %[[VAL_22:.*]] = memref.alloc() : memref<16xf64> -// CHECK: %[[VAL_23:.*]] = memref.cast %[[VAL_22]] : memref<16xf64> to memref -// CHECK: linalg.fill ins(%[[VAL_12]] : index) outs(%[[VAL_17]] : memref<3xindex>) -// CHECK: memref.store %[[VAL_10]], %[[VAL_16]]{{\[}}%[[VAL_12]]] : memref<2xindex> -// CHECK: memref.store %[[VAL_10]], %[[VAL_16]]{{\[}}%[[VAL_13]]] : memref<2xindex> -// CHECK: %[[VAL_24:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_19]], %[[VAL_12]] {idx = 0 : index} : memref<3xindex>, memref, index -// CHECK: %[[VAL_25:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_24]], %[[VAL_12]], %[[VAL_10]] {idx = 0 : index} : memref<3xindex>, memref, index, index -// CHECK: %[[VAL_26:.*]] = memref.alloc() : memref<4xf64> -// CHECK: %[[VAL_27:.*]] = memref.alloc() : memref<4xi1> -// CHECK: %[[VAL_28:.*]] = memref.alloc() : memref<4xindex> -// CHECK: %[[VAL_29:.*]] = 
memref.cast %[[VAL_28]] : memref<4xindex> to memref -// CHECK: linalg.fill ins(%[[VAL_11]] : f64) outs(%[[VAL_26]] : memref<4xf64>) -// CHECK: linalg.fill ins(%[[VAL_14]] : i1) outs(%[[VAL_27]] : memref<4xi1>) -// CHECK: %[[VAL_30:.*]]:5 = scf.for %[[VAL_31:.*]] = %[[VAL_12]] to %[[VAL_10]] step %[[VAL_13]] iter_args(%[[VAL_32:.*]] = %[[VAL_16]], %[[VAL_33:.*]] = %[[VAL_17]], %[[VAL_34:.*]] = %[[VAL_25]], %[[VAL_35:.*]] = %[[VAL_21]], %[[VAL_36:.*]] = %[[VAL_23]]) -> (memref<2xindex>, memref<3xindex>, memref, memref, memref) { -// CHECK: %[[VAL_37:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_31]]] : memref -// CHECK: %[[VAL_38:.*]] = arith.addi %[[VAL_31]], %[[VAL_13]] : index -// CHECK: %[[VAL_39:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_38]]] : memref -// CHECK: %[[VAL_40:.*]] = scf.for %[[VAL_41:.*]] = %[[VAL_37]] to %[[VAL_39]] step %[[VAL_13]] iter_args(%[[VAL_42:.*]] = %[[VAL_12]]) -> (index) { -// CHECK: %[[VAL_43:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_41]]] : memref -// CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_41]]] : memref -// CHECK: %[[VAL_45:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_43]]] : memref -// CHECK: %[[VAL_46:.*]] = arith.addi %[[VAL_43]], %[[VAL_13]] : index -// CHECK: %[[VAL_47:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_46]]] : memref -// CHECK: %[[VAL_48:.*]] = scf.for %[[VAL_49:.*]] = %[[VAL_45]] to %[[VAL_47]] step %[[VAL_13]] iter_args(%[[VAL_50:.*]] = %[[VAL_42]]) -> (index) { -// CHECK: %[[VAL_51:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_49]]] : memref -// CHECK: %[[VAL_52:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_51]]] : memref<4xf64> -// CHECK: %[[VAL_53:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_49]]] : memref -// CHECK: %[[VAL_54:.*]] = arith.mulf %[[VAL_44]], %[[VAL_53]] : f64 -// CHECK: %[[VAL_55:.*]] = arith.addf %[[VAL_52]], %[[VAL_54]] : f64 -// CHECK: %[[VAL_56:.*]] = memref.load %[[VAL_27]]{{\[}}%[[VAL_51]]] : memref<4xi1> -// CHECK: %[[VAL_57:.*]] = arith.cmpi eq, %[[VAL_56]], %[[VAL_14]] : i1 -// CHECK: 
%[[VAL_58:.*]] = scf.if %[[VAL_57]] -> (index) { -// CHECK: memref.store %[[VAL_15]], %[[VAL_27]]{{\[}}%[[VAL_51]]] : memref<4xi1> -// CHECK: memref.store %[[VAL_51]], %[[VAL_28]]{{\[}}%[[VAL_50]]] : memref<4xindex> -// CHECK: %[[VAL_59:.*]] = arith.addi %[[VAL_50]], %[[VAL_13]] : index -// CHECK: scf.yield %[[VAL_59]] : index +// CHECK: %[[VAL_20:.*]] = memref.alloc() : memref<16xf64> +// CHECK: %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<16xf64> to memref +// CHECK: %[[VAL_22:.*]] = llvm.mlir.undef : !llvm.struct<(array<2 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_23:.*]] = llvm.insertvalue %[[VAL_12]], %[[VAL_22]][1, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_24:.*]] = llvm.insertvalue %[[VAL_12]], %[[VAL_23]][1, 1] : !llvm.struct<(array<2 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_25:.*]] = llvm.insertvalue %[[VAL_12]], %[[VAL_24]][1, 2] : !llvm.struct<(array<2 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_26:.*]] = llvm.insertvalue %[[VAL_10]], %[[VAL_25]][0, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_27:.*]] = llvm.insertvalue %[[VAL_10]], %[[VAL_26]][0, 1] : !llvm.struct<(array<2 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_28:.*]], %[[VAL_29:.*]] = sparse_tensor.push_back %[[VAL_9]], %[[VAL_17]], %[[VAL_9]] : index, memref, index +// CHECK: %[[VAL_30:.*]] = arith.index_cast %[[VAL_29]] : index to i64 +// CHECK: %[[VAL_31:.*]] = llvm.insertvalue %[[VAL_30]], %[[VAL_27]][1, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_32:.*]], %[[VAL_33:.*]] = sparse_tensor.push_back %[[VAL_29]], %[[VAL_28]], %[[VAL_9]], %[[VAL_8]] : index, memref, index, index +// CHECK: %[[VAL_34:.*]] = arith.index_cast %[[VAL_33]] : index to i64 +// CHECK: %[[VAL_35:.*]] = llvm.insertvalue %[[VAL_34]], %[[VAL_31]][1, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_36:.*]] = memref.alloc() : memref<4xf64> +// CHECK: %[[VAL_37:.*]] = memref.alloc() : memref<4xi1> +// CHECK: %[[VAL_38:.*]] = 
memref.alloc() : memref<4xindex> +// CHECK: %[[VAL_39:.*]] = memref.cast %[[VAL_38]] : memref<4xindex> to memref +// CHECK: linalg.fill ins(%[[VAL_11]] : f64) outs(%[[VAL_36]] : memref<4xf64>) +// CHECK: linalg.fill ins(%[[VAL_14]] : i1) outs(%[[VAL_37]] : memref<4xi1>) +// CHECK: %[[VAL_40:.*]]:4 = scf.for %[[VAL_41:.*]] = %[[VAL_9]] to %[[VAL_8]] step %[[VAL_13]] iter_args(%[[VAL_42:.*]] = %[[VAL_32]], %[[VAL_43:.*]] = %[[VAL_19]], %[[VAL_44:.*]] = %[[VAL_21]], %[[VAL_45:.*]] = %[[VAL_35]]) -> (memref, memref, memref, !llvm.struct<(array<2 x i64>, array<3 x i64>)>) { +// CHECK: %[[VAL_46:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_41]]] : memref +// CHECK: %[[VAL_47:.*]] = arith.addi %[[VAL_41]], %[[VAL_13]] : index +// CHECK: %[[VAL_48:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_47]]] : memref +// CHECK: %[[VAL_49:.*]] = scf.for %[[VAL_50:.*]] = %[[VAL_46]] to %[[VAL_48]] step %[[VAL_13]] iter_args(%[[VAL_51:.*]] = %[[VAL_9]]) -> (index) { +// CHECK: %[[VAL_52:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_50]]] : memref +// CHECK: %[[VAL_53:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_50]]] : memref +// CHECK: %[[VAL_54:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_52]]] : memref +// CHECK: %[[VAL_55:.*]] = arith.addi %[[VAL_52]], %[[VAL_13]] : index +// CHECK: %[[VAL_56:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_55]]] : memref +// CHECK: %[[VAL_57:.*]] = scf.for %[[VAL_58:.*]] = %[[VAL_54]] to %[[VAL_56]] step %[[VAL_13]] iter_args(%[[VAL_59:.*]] = %[[VAL_51]]) -> (index) { +// CHECK: %[[VAL_60:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_58]]] : memref +// CHECK: %[[VAL_61:.*]] = memref.load %[[VAL_36]]{{\[}}%[[VAL_60]]] : memref<4xf64> +// CHECK: %[[VAL_62:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_58]]] : memref +// CHECK: %[[VAL_63:.*]] = arith.mulf %[[VAL_53]], %[[VAL_62]] : f64 +// CHECK: %[[VAL_64:.*]] = arith.addf %[[VAL_61]], %[[VAL_63]] : f64 +// CHECK: %[[VAL_65:.*]] = memref.load %[[VAL_37]]{{\[}}%[[VAL_60]]] : memref<4xi1> +// CHECK: %[[VAL_66:.*]] = arith.cmpi eq, 
%[[VAL_65]], %[[VAL_14]] : i1 +// CHECK: %[[VAL_67:.*]] = scf.if %[[VAL_66]] -> (index) { +// CHECK: memref.store %[[VAL_15]], %[[VAL_37]]{{\[}}%[[VAL_60]]] : memref<4xi1> +// CHECK: memref.store %[[VAL_60]], %[[VAL_38]]{{\[}}%[[VAL_59]]] : memref<4xindex> +// CHECK: %[[VAL_68:.*]] = arith.addi %[[VAL_59]], %[[VAL_13]] : index +// CHECK: scf.yield %[[VAL_68]] : index // CHECK: } else { -// CHECK: scf.yield %[[VAL_50]] : index +// CHECK: scf.yield %[[VAL_59]] : index // CHECK: } -// CHECK: memref.store %[[VAL_55]], %[[VAL_26]]{{\[}}%[[VAL_51]]] : memref<4xf64> -// CHECK: scf.yield %[[VAL_60:.*]] : index -// CHECK: } {"Emitted from" = "linalg.generic"} -// CHECK: sparse_tensor.sort %[[VAL_62:.*]], %[[VAL_29]] : memref -// CHECK: %[[VAL_63:.*]]:5 = scf.for %[[VAL_64:.*]] = %[[VAL_12]] to %[[VAL_62]] step %[[VAL_13]] iter_args(%[[VAL_65:.*]] = %[[VAL_32]], %[[VAL_66:.*]] = %[[VAL_33]], %[[VAL_67:.*]] = %[[VAL_34]], %[[VAL_68:.*]] = %[[VAL_35]], %[[VAL_69:.*]] = %[[VAL_36]]) -> (memref<2xindex>, memref<3xindex>, memref, memref, memref) { -// CHECK: %[[VAL_70:.*]] = memref.load %[[VAL_28]]{{\[}}%[[VAL_64]]] : memref<4xindex> -// CHECK: %[[VAL_71:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_70]]] : memref<4xf64> -// CHECK: %[[VAL_72:.*]]:5 = func.call @_insert_D_C_4_4_f64_0_0(%[[VAL_65]], %[[VAL_66]], %[[VAL_67]], %[[VAL_68]], %[[VAL_69]], %[[VAL_31]], %[[VAL_70]], %[[VAL_71]]) : (memref<2xindex>, memref<3xindex>, memref, memref, memref, index, index, f64) -> (memref<2xindex>, memref<3xindex>, memref, memref, memref) -// CHECK: memref.store %[[VAL_11]], %[[VAL_26]]{{\[}}%[[VAL_70]]] : memref<4xf64> -// CHECK: memref.store %[[VAL_14]], %[[VAL_27]]{{\[}}%[[VAL_70]]] : memref<4xi1> -// CHECK: scf.yield %[[VAL_72]]#0, %[[VAL_72]]#1, %[[VAL_72]]#2, %[[VAL_72]]#3, %[[VAL_72]]#4 : memref<2xindex>, memref<3xindex>, memref, memref, memref +// CHECK: memref.store %[[VAL_64]], %[[VAL_36]]{{\[}}%[[VAL_60]]] : memref<4xf64> +// CHECK: scf.yield %[[VAL_69:.*]] : index +// CHECK: } +// 
CHECK: scf.yield %[[VAL_70:.*]] : index // CHECK: } -// CHECK: scf.yield %[[VAL_73:.*]]#0, %[[VAL_73]]#1, %[[VAL_73]]#2, %[[VAL_73]]#3, %[[VAL_73]]#4 : memref<2xindex>, memref<3xindex>, memref, memref, memref -// CHECK: } {"Emitted from" = "linalg.generic"} -// CHECK: memref.dealloc %[[VAL_26]] : memref<4xf64> -// CHECK: memref.dealloc %[[VAL_27]] : memref<4xi1> -// CHECK: memref.dealloc %[[VAL_28]] : memref<4xindex> -// CHECK: %[[VAL_74:.*]] = memref.load %[[VAL_75:.*]]#1{{\[}}%[[VAL_12]]] : memref<3xindex> -// CHECK: %[[VAL_76:.*]] = memref.load %[[VAL_75]]#2{{\[}}%[[VAL_12]]] : memref -// CHECK: %[[VAL_77:.*]] = scf.for %[[VAL_78:.*]] = %[[VAL_13]] to %[[VAL_74]] step %[[VAL_13]] iter_args(%[[VAL_79:.*]] = %[[VAL_76]]) -> (index) { -// CHECK: %[[VAL_80:.*]] = memref.load %[[VAL_75]]#2{{\[}}%[[VAL_78]]] : memref -// CHECK: %[[VAL_81:.*]] = arith.cmpi eq, %[[VAL_80]], %[[VAL_12]] : index -// CHECK: %[[VAL_82:.*]] = arith.select %[[VAL_81]], %[[VAL_79]], %[[VAL_80]] : index -// CHECK: scf.if %[[VAL_81]] { -// CHECK: memref.store %[[VAL_79]], %[[VAL_75]]#2{{\[}}%[[VAL_78]]] : memref +// CHECK: sparse_tensor.sort %[[VAL_71:.*]], %[[VAL_39]] : memref +// CHECK: %[[VAL_72:.*]]:4 = scf.for %[[VAL_73:.*]] = %[[VAL_9]] to %[[VAL_71]] step %[[VAL_13]] iter_args(%[[VAL_74:.*]] = %[[VAL_42]], %[[VAL_75:.*]] = %[[VAL_43]], %[[VAL_76:.*]] = %[[VAL_44]], %[[VAL_77:.*]] = %[[VAL_45]]) -> (memref, memref, memref, !llvm.struct<(array<2 x i64>, array<3 x i64>)>) { +// CHECK: %[[VAL_78:.*]] = memref.load %[[VAL_38]]{{\[}}%[[VAL_73]]] : memref<4xindex> +// CHECK: %[[VAL_79:.*]] = memref.load %[[VAL_36]]{{\[}}%[[VAL_78]]] : memref<4xf64> +// CHECK: %[[VAL_80:.*]]:4 = func.call @_insert_D_C_4_4_f64_0_0(%[[VAL_74]], %[[VAL_75]], %[[VAL_76]], %[[VAL_77]], %[[VAL_41]], %[[VAL_78]], %[[VAL_79]]) : (memref, memref, memref, !llvm.struct<(array<2 x i64>, array<3 x i64>)>, index, index, f64) -> (memref, memref, memref, !llvm.struct<(array<2 x i64>, array<3 x i64>)>) +// CHECK: memref.store 
%[[VAL_11]], %[[VAL_36]]{{\[}}%[[VAL_78]]] : memref<4xf64> +// CHECK: memref.store %[[VAL_14]], %[[VAL_37]]{{\[}}%[[VAL_78]]] : memref<4xi1> +// CHECK: scf.yield %[[VAL_80]]#0, %[[VAL_80]]#1, %[[VAL_80]]#2, %[[VAL_80]]#3 : memref, memref, memref, !llvm.struct<(array<2 x i64>, array<3 x i64>)> +// CHECK: } +// CHECK: scf.yield %[[VAL_81:.*]]#0, %[[VAL_81]]#1, %[[VAL_81]]#2, %[[VAL_81]]#3 : memref, memref, memref, !llvm.struct<(array<2 x i64>, array<3 x i64>)> +// CHECK: } +// CHECK: memref.dealloc %[[VAL_36]] : memref<4xf64> +// CHECK: memref.dealloc %[[VAL_37]] : memref<4xi1> +// CHECK: memref.dealloc %[[VAL_38]] : memref<4xindex> +// CHECK: %[[VAL_82:.*]] = llvm.extractvalue %[[VAL_83:.*]]#3[1, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_84:.*]] = arith.index_cast %[[VAL_82]] : i64 to index +// CHECK: %[[VAL_85:.*]] = memref.load %[[VAL_83]]#0{{\[}}%[[VAL_9]]] : memref +// CHECK: %[[VAL_86:.*]] = scf.for %[[VAL_87:.*]] = %[[VAL_13]] to %[[VAL_84]] step %[[VAL_13]] iter_args(%[[VAL_88:.*]] = %[[VAL_85]]) -> (index) { +// CHECK: %[[VAL_89:.*]] = memref.load %[[VAL_83]]#0{{\[}}%[[VAL_87]]] : memref +// CHECK: %[[VAL_90:.*]] = arith.cmpi eq, %[[VAL_89]], %[[VAL_9]] : index +// CHECK: %[[VAL_91:.*]] = arith.select %[[VAL_90]], %[[VAL_88]], %[[VAL_89]] : index +// CHECK: scf.if %[[VAL_90]] { +// CHECK: memref.store %[[VAL_88]], %[[VAL_83]]#0{{\[}}%[[VAL_87]]] : memref // CHECK: } -// CHECK: scf.yield %[[VAL_82]] : index +// CHECK: scf.yield %[[VAL_91]] : index // CHECK: } -// CHECK: return %[[VAL_75]]#0, %[[VAL_75]]#1, %[[VAL_75]]#2, %[[VAL_75]]#3, %[[VAL_75]]#4 : memref<2xindex>, memref<3xindex>, memref, memref, memref +// CHECK: return %[[VAL_83]]#0, %[[VAL_83]]#1, %[[VAL_83]]#2, %[[VAL_83]]#3 : memref, memref, memref, !llvm.struct<(array<2 x i64>, array<3 x i64>)> // CHECK: } func.func @matmul(%A: tensor<4x8xf64, #CSR>, %B: tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR> { diff --git 
a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir @@ -24,16 +24,15 @@ %buffer = memref.alloc(%c1) : memref memref.store %c0, %bufferSizes[%c0] : memref - %buffer2 = sparse_tensor.push_back %bufferSizes, %buffer, %d2 {idx=0 : index} : memref, memref, f32 - %buffer3 = sparse_tensor.push_back %bufferSizes, %buffer2, %d1, %c10 {idx=0 : index} : memref, memref, f32, index + %buffer2, %s0 = sparse_tensor.push_back %c0, %buffer, %d2 : index, memref, f32 + %buffer3, %s1 = sparse_tensor.push_back %s0, %buffer2, %d1, %c10 : index, memref, f32, index // CHECK: 16 %capacity = memref.dim %buffer3, %c0 : memref vector.print %capacity : index - // CHECK: ( 11 ) - %size = vector.transfer_read %bufferSizes[%c0], %c0: memref, vector<1xindex> - vector.print %size : vector<1xindex> + // CHECK: 11 + vector.print %s1 : index - // CHECK ( 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 ) + // CHECK: ( 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 ) %values = vector.transfer_read %buffer3[%c0], %d0: memref, vector<11xf32>