diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h @@ -25,7 +25,9 @@ namespace bufferization { struct OneShotBufferizationOptions; } // namespace bufferization - +namespace sparse_tensor { +class SparseTensorTypeToBufferConverter; +} // namespace sparse_tensor //===----------------------------------------------------------------------===// // The Sparsification pass. //===----------------------------------------------------------------------===// @@ -122,16 +124,10 @@ // The SparseTensorCodegen pass. //===----------------------------------------------------------------------===// -/// Sparse tensor type converter into an actual buffer. -class SparseTensorTypeToBufferConverter : public TypeConverter { -public: - SparseTensorTypeToBufferConverter(); -}; - /// Sets up sparse tensor conversion rules.
-void populateSparseTensorCodegenPatterns(TypeConverter &typeConverter, - RewritePatternSet &patterns, - bool enableBufferInitialization); +void populateSparseTensorCodegenPatterns( + sparse_tensor::SparseTensorTypeToBufferConverter &typeConverter, + RewritePatternSet &patterns, bool enableBufferInitialization); std::unique_ptr createSparseTensorCodegenPass(); std::unique_ptr diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td @@ -209,6 +209,7 @@ let dependentDialects = [ "arith::ArithDialect", "bufferization::BufferizationDialect", + "LLVM::LLVMDialect", "linalg::LinalgDialect", "memref::MemRefDialect", "scf::SCFDialect", diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt @@ -2,6 +2,7 @@ BufferizableOpInterfaceImpl.cpp CodegenUtils.cpp SparseBufferRewriting.cpp + SparseTensorBuilder.cpp SparseTensorCodegen.cpp SparseTensorConversion.cpp SparseTensorPasses.cpp diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h @@ -310,220 +310,6 @@ return !rtp || rtp.getRank() == 0; } -//===----------------------------------------------------------------------===// -// SparseTensorDescriptor and helpers, manage the sparse tensor memory layout -// scheme. -// -// Sparse tensor storage scheme for rank-dimensional tensor is organized -// as a single compound type with the following fields. Note that every -// memref with ? size actually behaves as a "vector", i.e.
the stored -// size is the capacity and the used size resides in the memSizes array. -// -// struct { -// memref dimSizes ; size in each dimension -// memref memSizes ; sizes of ptrs/inds/values -// ; per-dimension d: -// ; if dense: -// -// ; if compresed: -// memref pointers-d ; pointers for sparse dim d -// memref indices-d ; indices for sparse dim d -// ; if singleton: -// memref indices-d ; indices for singleton dim d -// memref values ; values -// }; -// -//===----------------------------------------------------------------------===// -enum class SparseTensorFieldKind { - DimSizes, - MemSizes, - PtrMemRef, - IdxMemRef, - ValMemRef -}; - -constexpr uint64_t dimSizesIdx = 0; -constexpr uint64_t memSizesIdx = dimSizesIdx + 1; -constexpr uint64_t dataFieldIdx = memSizesIdx + 1; - -/// For each field that will be allocated for the given sparse tensor encoding, -/// calls the callback with the corresponding field index, field kind, dimension -/// (for sparse tensor level memrefs) and dimlevelType. -/// The field index always starts with zero and increments by one between two -/// callback invocations. -/// Ideally, all other methods should rely on this function to query a sparse -/// tensor fields instead of relying on ad-hoc index computation. -void foreachFieldInSparseTensor( - SparseTensorEncodingAttr, - llvm::function_ref); - -/// Same as above, except that it also builds the Type for the corresponding -/// field. -void foreachFieldAndTypeInSparseTensor( - RankedTensorType, - llvm::function_ref); - -/// Gets the total number of fields for the given sparse tensor encoding. -unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc); - -/// Gets the total number of data fields (index arrays, pointer arrays, and a -/// value array) for the given sparse tensor encoding. -unsigned getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc); - -/// Get the index of the field in memSizes (only valid for data fields).
-inline unsigned getFieldMemSizesIndex(unsigned fid) { - assert(fid >= dataFieldIdx); - return fid - dataFieldIdx; -} - -template -struct SparseTensorValueArrayRef; - -// Uses ValueRange for immuatable descriptors; uses SmallVectorImpl & -// for mutable descriptors. -template <> -struct SparseTensorValueArrayRef { - using ValueArray = ValueRange; -}; - -// Using SmallVector for mutable descriptor allows users to reuse it as a tmp -// buffers to append value for some special cases, though users should be -// responsible to restore the buffer to legal states after their use. It is -// probably not a clean way, but it is the most efficient way to avoid copying -// the fields into another SmallVector. If a more clear way is wanted, we -// should change it to MutableArrayRef instead. -template <> -struct SparseTensorValueArrayRef { - using ValueArray = SmallVectorImpl &; -}; - -/// A helper class around an array of values that corresponding to a sparse -/// tensor, provides a set of meaningful APIs to query and update a particular -/// field in a consistent way. -/// Users should not make assumption on how a sparse tensor is laid out but -/// instead relies on this class to access the right value for the right field. -template -class SparseTensorDescriptorImpl { -private: - using Storage = typename SparseTensorValueArrayRef::ValueArray; - -public: - SparseTensorDescriptorImpl(Type tp, Storage fields) - : rType(tp.cast()), fields(fields) { - assert(getSparseTensorEncoding(tp) && - getNumFieldsFromEncoding(getSparseTensorEncoding(tp)) == - fields.size()); - // We should make sure the class is trivially copyable (and should be small - // enough) such that we can pass it by value. - static_assert( - std::is_trivially_copyable_v>); - } - - // Implicit (and cheap) type conversion from MutSparseTensorDescriptor to - // SparseTensorDescriptor.
- template > - /*implicit*/ SparseTensorDescriptorImpl(std::enable_if_t &mDesc) - : rType(mDesc.getTensorType()), fields(mDesc.getFields()) {} - - /// - /// Getters: get the field index for required field. - /// - - unsigned getPtrMemRefIndex(unsigned ptrDim) const { - return getFieldIndex(ptrDim, SparseTensorFieldKind::PtrMemRef); - } - - unsigned getIdxMemRefIndex(unsigned idxDim) const { - return getFieldIndex(idxDim, SparseTensorFieldKind::IdxMemRef); - } - - unsigned getValMemRefIndex() const { return fields.size() - 1; } - - unsigned getPtrMemSizesIndex(unsigned dim) const { - return getPtrMemRefIndex(dim) - dataFieldIdx; - } - - unsigned getIdxMemSizesIndex(unsigned dim) const { - return getIdxMemRefIndex(dim) - dataFieldIdx; - } - - unsigned getValMemSizesIndex() const { - return getValMemRefIndex() - dataFieldIdx; - } - - unsigned getNumFields() const { return fields.size(); } - - /// - /// Getters: get the value for required field. - /// - - Value getDimSizesMemRef() const { return fields[dimSizesIdx]; } - Value getMemSizesMemRef() const { return fields[memSizesIdx]; } - - Value getPtrMemRef(unsigned ptrDim) const { - return fields[getPtrMemRefIndex(ptrDim)]; - } - - Value getIdxMemRef(unsigned idxDim) const { - return fields[getIdxMemRefIndex(idxDim)]; - } - - Value getValMemRef() const { return fields[getValMemRefIndex()]; } - - Value getField(unsigned fid) const { - assert(fid < fields.size()); - return fields[fid]; - } - - /// - /// Setters: update the value for required field (only enabled for - /// MutSparseTensorDescriptor).
- /// - - template - void setField(unsigned fid, std::enable_if_t v) { - assert(fid < fields.size()); - fields[fid] = v; - } - - RankedTensorType getTensorType() const { return rType; } - Storage getFields() const { return fields; } - - Type getElementType(unsigned fidx) const { - return fields[fidx].getType().template cast().getElementType(); - } - -private: - unsigned getFieldIndex(unsigned dim, SparseTensorFieldKind kind) const { - unsigned fieldIdx = -1u; - foreachFieldInSparseTensor( - getSparseTensorEncoding(rType), - [dim, kind, &fieldIdx](unsigned fIdx, SparseTensorFieldKind fKind, - unsigned fDim, DimLevelType dlt) -> bool { - if (fDim == dim && kind == fKind) { - fieldIdx = fIdx; - // Returns false to break the iteration. - return false; - } - return true; - }); - assert(fieldIdx != -1u); - return fieldIdx; - } - - RankedTensorType rType; - Storage fields; -}; - -using SparseTensorDescriptor = SparseTensorDescriptorImpl; -using MutSparseTensorDescriptor = SparseTensorDescriptorImpl; - //===----------------------------------------------------------------------===// // SparseTensorLoopEmiter class, manages sparse tensors and helps to // generate loop structure to (co)-iterate sparse tensors.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp @@ -90,116 +90,6 @@ return val; } -void sparse_tensor::foreachFieldInSparseTensor( - const SparseTensorEncodingAttr enc, - llvm::function_ref - callback) { - assert(enc); - -#define RETURN_ON_FALSE(idx, kind, dim, dlt) \ - if (!(callback(idx, kind, dim, dlt))) \ - return; - - RETURN_ON_FALSE(dimSizesIdx, SparseTensorFieldKind::DimSizes, -1u, - DimLevelType::Undef); - RETURN_ON_FALSE(memSizesIdx, SparseTensorFieldKind::MemSizes, -1u, - DimLevelType::Undef); - - static_assert(dataFieldIdx == memSizesIdx + 1); - unsigned fieldIdx = dataFieldIdx; - // Per-dimension storage. - for (unsigned r = 0, rank = enc.getDimLevelType().size(); r < rank; r++) { - // Dimension level types apply in order to the reordered dimension. - // As a result, the compound type can be constructed directly in the given - // order. - auto dlt = getDimLevelType(enc, r); - if (isCompressedDLT(dlt)) { - RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PtrMemRef, r, dlt); - RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt); - } else if (isSingletonDLT(dlt)) { - RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt); - } else { - assert(isDenseDLT(dlt)); // no fields - } - } - // The values array. - RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::ValMemRef, -1u, - DimLevelType::Undef); - -#undef RETURN_ON_FALSE -} - -void sparse_tensor::foreachFieldAndTypeInSparseTensor( - RankedTensorType rType, - llvm::function_ref - callback) { - auto enc = getSparseTensorEncoding(rType); - assert(enc); - // Construct the basic types.
- Type indexType = IndexType::get(enc.getContext()); - Type idxType = enc.getIndexType(); - Type ptrType = enc.getPointerType(); - Type eltType = rType.getElementType(); - unsigned rank = rType.getShape().size(); - // memref dimSizes - Type dimSizeType = MemRefType::get({rank}, indexType); - // memref memSizes - Type memSizeType = - MemRefType::get({getNumDataFieldsFromEncoding(enc)}, indexType); - // memref pointers - Type ptrMemType = MemRefType::get({ShapedType::kDynamic}, ptrType); - // memref indices - Type idxMemType = MemRefType::get({ShapedType::kDynamic}, idxType); - // memref values - Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType); - - foreachFieldInSparseTensor( - enc, - [dimSizeType, memSizeType, ptrMemType, idxMemType, valMemType, - callback](unsigned fieldIdx, SparseTensorFieldKind fieldKind, - unsigned dim, DimLevelType dlt) -> bool { - switch (fieldKind) { - case SparseTensorFieldKind::DimSizes: - return callback(dimSizeType, fieldIdx, fieldKind, dim, dlt); - case SparseTensorFieldKind::MemSizes: - return callback(memSizeType, fieldIdx, fieldKind, dim, dlt); - case SparseTensorFieldKind::PtrMemRef: - return callback(ptrMemType, fieldIdx, fieldKind, dim, dlt); - case SparseTensorFieldKind::IdxMemRef: - return callback(idxMemType, fieldIdx, fieldKind, dim, dlt); - case SparseTensorFieldKind::ValMemRef: - return callback(valMemType, fieldIdx, fieldKind, dim, dlt); - }; - llvm_unreachable("unrecognized field kind"); - }); -} - -unsigned sparse_tensor::getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) { - unsigned numFields = 0; - foreachFieldInSparseTensor(enc, - [&numFields](unsigned, SparseTensorFieldKind, - unsigned, DimLevelType) -> bool { - numFields++; - return true; - }); - return numFields; -} - -unsigned -sparse_tensor::getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc) { - unsigned numFields = 0; // one value memref - foreachFieldInSparseTensor(enc, - [&numFields](unsigned fidx, SparseTensorFieldKind, -
unsigned, DimLevelType) -> bool { - if (fidx >= dataFieldIdx) - numFields++; - return true; - }); - assert(numFields == getNumFieldsFromEncoding(enc) - dataFieldIdx); - return numFields; -} //===----------------------------------------------------------------------===// // Sparse tensor loop emitter class implementations //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorBuilder.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorBuilder.h new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorBuilder.h @@ -0,0 +1,309 @@ +//===- SparseTensorBuilder.h ------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines utilities for lowering and accessing sparse tensor +// types. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORBUILDER_H_ +#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORBUILDER_H_ + +#include "mlir/Conversion/LLVMCommon/StructBuilder.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace sparse_tensor { +/// Converts sparse tensor types into their underlying buffer fields. +class SparseTensorTypeToBufferConverter : public TypeConverter { +public: + explicit SparseTensorTypeToBufferConverter(unsigned indexWidth); + + /// Gets the bitwidth of the index type when converted to LLVM.
+ unsigned getIndexTypeBitwidth() const { return indexWidth; } + + Optional + convertSparseTensorType(RankedTensorType rtp, SmallVectorImpl &fields); + +private: + unsigned indexWidth; +}; + +class SparseTensorMetaData : public StructBuilder { +public: + explicit SparseTensorMetaData(Value metadata) : StructBuilder(metadata) { + assert(value); // `value` (inherited from StructBuilder) must have been set. + } + + static SparseTensorMetaData undef(OpBuilder &builder, Location loc, + Type metaType); + static Type getMetaDataIndexType(Value data); + + Value dimSize(OpBuilder &builder, Location loc, unsigned dim); + void setDimSize(OpBuilder &builder, Location loc, unsigned dim, Value size); +}; + +//===----------------------------------------------------------------------===// +// SparseTensorDescriptor and helpers, manage the sparse tensor memory layout +// scheme. +// +// Sparse tensor storage scheme for rank-dimensional tensor is organized +// as a single compound type with the following fields. Note that every +// memref with ? size actually behaves as a "vector", i.e. the stored +// size is the capacity and the used size resides in the memSizes array.
+// +// struct { +// memref memSizes ; sizes of ptrs/inds/values +// ; per-dimension d: +// ; if dense: +// +// ; if compressed: +// memref pointers-d ; pointers for sparse dim d +// memref indices-d ; indices for sparse dim d +// ; if singleton: +// memref indices-d ; indices for singleton dim d +// memref values ; values +// ; sparse tensor metadata (always the last field) +// struct { +// array dimSizes ; size in each dimension +// } +// }; +// +//===----------------------------------------------------------------------===// +enum class SparseTensorFieldKind { + MetaData, + MemSizes, + PtrMemRef, + IdxMemRef, + ValMemRef +}; + +constexpr uint64_t kMemSizesIdx = 0; +constexpr uint64_t kDataFieldIdx = kMemSizesIdx + 1; + +/// For each field that will be allocated for the given sparse tensor encoding, +/// calls the callback with the corresponding field index, field kind, dimension +/// (for sparse tensor level memrefs) and dimlevelType. +/// The field index always starts with zero and increments by one between two +/// callback invocations. +/// Ideally, all other methods should rely on this function to query a sparse +/// tensor fields instead of relying on ad-hoc index computation. +void foreachFieldInSparseTensor( + SparseTensorEncodingAttr, + llvm::function_ref); + +/// Same as above, except that it also builds the Type for the corresponding +/// field, using the given index bitwidth for the metadata struct. +void foreachFieldAndTypeInSparseTensor( + RankedTensorType, unsigned, + llvm::function_ref); + +/// Gets the total number of fields for the given sparse tensor encoding. +unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc); + +/// Gets the total number of data fields (index arrays, pointer arrays, and a +/// value array) for the given sparse tensor encoding. +unsigned getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc); + +/// Get the index of the field in memSizes (only valid for data fields).
+inline unsigned getFieldMemSizesIndex(unsigned fid) { + assert(fid >= kDataFieldIdx); + return fid - kDataFieldIdx; +} + +/// A helper class around an array of values that correspond to a sparse +/// tensor, provides a set of meaningful APIs to query and update a particular +/// field in a consistent way. +/// Users should not make assumptions on how a sparse tensor is laid out but +/// instead rely on this class to access the right value for the right field. +template +class SparseTensorDescriptorImpl { +private: + // Uses ValueRange for immutable descriptors; uses SmallVectorImpl & + // for mutable descriptors. + // Using SmallVector for mutable descriptor allows users to reuse it as a tmp + // buffer to append values for some special cases, though users should be + // responsible for restoring the buffer to a legal state after their use. It + // is probably not a clean way, but it is the most efficient way to avoid + // copying the fields into another SmallVector. If a cleaner way is wanted, we + // should change it to MutableArrayRef instead. + using ValueArrayRef = typename std::conditional &, + ValueRange>::type; + +public: + SparseTensorDescriptorImpl(Type tp, ValueArrayRef fields) + : indexType(SparseTensorMetaData::getMetaDataIndexType(fields.back())), + rType(tp.cast()), fields(fields) { + assert(getSparseTensorEncoding(tp) && + getNumFieldsFromEncoding(getSparseTensorEncoding(tp)) == + fields.size()); + // We should make sure the class is trivially copyable (and should be small + // enough) such that we can pass it by value. + static_assert( + std::is_trivially_copyable_v>); + } + + // Implicit (and cheap) type conversion from MutSparseTensorDescriptor to + // SparseTensorDescriptor. + template > + /*implicit*/ SparseTensorDescriptorImpl(std::enable_if_t &mDesc) + : indexType(mDesc.getIndexType()), rType(mDesc.getTensorType()), + fields(mDesc.getFields()) {} + + /// + /// Getters: get the field index for required field.
+ /// + + unsigned getPtrMemRefIndex(unsigned ptrDim) const { + return getFieldIndex(ptrDim, SparseTensorFieldKind::PtrMemRef); + } + + unsigned getIdxMemRefIndex(unsigned idxDim) const { + return getFieldIndex(idxDim, SparseTensorFieldKind::IdxMemRef); + } + // Values sit right before the trailing metadata field. + unsigned getValMemRefIndex() const { return fields.size() - 2; } + + unsigned getPtrMemSizesIndex(unsigned dim) const { + return getPtrMemRefIndex(dim) - kDataFieldIdx; + } + + unsigned getIdxMemSizesIndex(unsigned dim) const { + return getIdxMemRefIndex(dim) - kDataFieldIdx; + } + + unsigned getValMemSizesIndex() const { + return getValMemRefIndex() - kDataFieldIdx; + } + + unsigned getNumFields() const { return fields.size(); } + + /// + /// Getters: get the value for required field. + /// + + Value getDimSize(OpBuilder &builder, Location loc, unsigned dim) const { + SparseTensorMetaData md(fields.back()); + return md.dimSize(builder, loc, dim); + } + + Value getMemSizesMemRef() const { return fields[kMemSizesIdx]; } + + Value getPtrMemRef(unsigned ptrDim) const { + return fields[getPtrMemRefIndex(ptrDim)]; + } + + Value getIdxMemRef(unsigned idxDim) const { + return fields[getIdxMemRefIndex(idxDim)]; + } + + Value getValMemRef() const { return fields[getValMemRefIndex()]; } + + Value getField(unsigned fid) const { + assert(fid < fields.size()); + return fields[fid]; + } + + /// + /// Setters: update the value for required field (only enabled for + /// MutSparseTensorDescriptor).
+ /// + + template + void setField(unsigned fid, std::enable_if_t v) { + assert(fid < fields.size()); + fields[fid] = v; + } + + template + void setDimSize(OpBuilder &builder, Location loc, unsigned dim, + std::enable_if_t v) { + assert(v.getType() == getIndexType()); + SparseTensorMetaData md(fields.back()); + md.setDimSize(builder, loc, dim, v); + fields.back() = md; + } + + ValueRange getMemRefFields() const { + ValueRange ret = fields; + // Drop the trailing metadata field; all preceding fields are memrefs. + return ret.slice(0, fields.size() - 1); + } + + RankedTensorType getTensorType() const { return rType; } + ValueArrayRef getFields() const { return fields; } + Type getIndexType() const { return indexType; } + Type getElementType(unsigned fidx) const { + return fields[fidx].getType().template cast().getElementType(); + } + +private: + unsigned getFieldIndex(unsigned dim, SparseTensorFieldKind kind) const { + unsigned fieldIdx = -1u; + foreachFieldInSparseTensor( + getSparseTensorEncoding(rType), + [dim, kind, &fieldIdx](unsigned fIdx, SparseTensorFieldKind fKind, + unsigned fDim, DimLevelType dlt) -> bool { + if (fDim == dim && kind == fKind) { + fieldIdx = fIdx; + // Returns false to break the iteration. + return false; + } + return true; + }); + assert(fieldIdx != -1u); + return fieldIdx; + } + + Type indexType; + RankedTensorType rType; + ValueArrayRef fields; +}; + +using SparseTensorDescriptor = SparseTensorDescriptorImpl; +using MutSparseTensorDescriptor = SparseTensorDescriptorImpl; + +/// Returns the "tuple" value of the adapted tensor. +inline UnrealizedConversionCastOp getTuple(Value tensor) { + return llvm::cast(tensor.getDefiningOp()); +} + +/// Packs the given values as a "tuple" value.
+inline Value genTuple(OpBuilder &builder, Location loc, Type tp, + ValueRange values) { + return builder.create(loc, TypeRange(tp), values) + .getResult(0); +} + +inline Value genTuple(OpBuilder &builder, Location loc, + SparseTensorDescriptor desc) { + return genTuple(builder, loc, desc.getTensorType(), desc.getFields()); +} + +inline SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) { + auto tuple = getTuple(tensor); + return SparseTensorDescriptor(tuple.getResultTypes()[0], tuple.getInputs()); +} + +inline MutSparseTensorDescriptor +getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl &fields) { + auto tuple = getTuple(tensor); + fields.assign(tuple.getInputs().begin(), tuple.getInputs().end()); + return MutSparseTensorDescriptor(tuple.getResultTypes()[0], fields); +} + +} // namespace sparse_tensor +} // namespace mlir +#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORBUILDER_H_ diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorBuilder.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorBuilder.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorBuilder.cpp @@ -0,0 +1,216 @@ +//===- SparseTensorBuilder.cpp --------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "SparseTensorBuilder.h" + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; +using namespace sparse_tensor; + +static SmallVector +getSparseTensorMetaDataFields(RankedTensorType rtp, unsigned indexBitwidth) { + MLIRContext *ctx = rtp.getContext(); + SmallVector result; + + auto dimSizes = LLVM::LLVMArrayType::get( + ctx, IntegerType::get(ctx, indexBitwidth), rtp.getRank()); + result.push_back(dimSizes); + return result; +} + +static Type convertSparseTensorMetaData(RankedTensorType rtp, + unsigned indexBitwidth) { + return LLVM::LLVMStructType::getLiteral( + rtp.getContext(), getSparseTensorMetaDataFields(rtp, indexBitwidth)); +} + +Optional +SparseTensorTypeToBufferConverter::convertSparseTensorType( + RankedTensorType rtp, SmallVectorImpl &fields) { + auto enc = getSparseTensorEncoding(rtp); + if (!enc) + return std::nullopt; + + foreachFieldAndTypeInSparseTensor( + rtp, getIndexTypeBitwidth(), + [&fields](Type fieldType, unsigned fieldIdx, + SparseTensorFieldKind /*fieldKind*/, unsigned /*dim*/, + DimLevelType /*dlt*/) -> bool { + assert(fieldIdx == fields.size()); + fields.push_back(fieldType); + return true; + }); + return success(); +} + +SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter( + unsigned indexWidth) + : indexWidth(indexWidth) { + addConversion([](Type type) { return type; }); + addConversion([this](RankedTensorType rtp, SmallVectorImpl &fields) { + return convertSparseTensorType(rtp, fields); + }); + + // Required by scf.for 1:N type conversion. + addSourceMaterialization([](OpBuilder &builder, RankedTensorType tp, + ValueRange inputs, + Location loc) -> Optional { + if (!getSparseTensorEncoding(tp)) + // Not a sparse tensor.
+ return std::nullopt; + // Sparse compiler knows how to cancel out these casts. + return genTuple(builder, loc, tp, inputs); + }); +} + +// Position of the dimSizes array within the metadata struct. +constexpr uint64_t kDimSizePosInMetaData = 0; + +SparseTensorMetaData SparseTensorMetaData::undef(OpBuilder &builder, + Location loc, Type metaType) { + Value metaData = builder.create(loc, metaType); + return SparseTensorMetaData(metaData); +} + +Type SparseTensorMetaData::getMetaDataIndexType(Value data) { + return data.getType() + .cast() + .getBody()[kDimSizePosInMetaData] + .cast() + .getElementType(); +} + +/// Builds IR inserting `size` as entry `dim` of the metadata dimSizes array. +void SparseTensorMetaData::setDimSize(OpBuilder &builder, Location loc, + unsigned dim, Value size) { + value = builder.create( + loc, value, size, ArrayRef({kDimSizePosInMetaData, dim})); +} + +/// Builds IR extracting entry `dim` of the metadata dimSizes array. +Value SparseTensorMetaData::dimSize(OpBuilder &builder, Location loc, + unsigned dim) { + return builder.create( + loc, value, ArrayRef({kDimSizePosInMetaData, dim})); +} + +void sparse_tensor::foreachFieldInSparseTensor( + const SparseTensorEncodingAttr enc, + llvm::function_ref + callback) { + assert(enc); + +#define RETURN_ON_FALSE(idx, kind, dim, dlt) \ + if (!(callback(idx, kind, dim, dlt))) \ + return; + + RETURN_ON_FALSE(kMemSizesIdx, SparseTensorFieldKind::MemSizes, -1u, + DimLevelType::Undef); + + static_assert(kDataFieldIdx == kMemSizesIdx + 1); + unsigned fieldIdx = kDataFieldIdx; + // Per-dimension storage. + for (unsigned r = 0, rank = enc.getDimLevelType().size(); r < rank; r++) { + // Dimension level types apply in order to the reordered dimension. + // As a result, the compound type can be constructed directly in the given + // order.
+ auto dlt = getDimLevelType(enc, r); + if (isCompressedDLT(dlt)) { + RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PtrMemRef, r, dlt); + RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt); + } else if (isSingletonDLT(dlt)) { + RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt); + } else { + assert(isDenseDLT(dlt)); // no fields + } + } + // The values array. + RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::ValMemRef, -1u, + DimLevelType::Undef); + + // Put metadata at the end; descriptors read it via fields.back(). + RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::MetaData, -1u, + DimLevelType::Undef); + +#undef RETURN_ON_FALSE +} + +void sparse_tensor::foreachFieldAndTypeInSparseTensor( + RankedTensorType rType, unsigned indexBitwidth, + llvm::function_ref + callback) { + auto enc = getSparseTensorEncoding(rType); + assert(enc); + // Construct the basic types. + Type indexType = IndexType::get(enc.getContext()); + Type idxType = enc.getIndexType(); + Type ptrType = enc.getPointerType(); + Type eltType = rType.getElementType(); + + Type metaDataType = convertSparseTensorMetaData(rType, indexBitwidth); + // memref memSizes + Type memSizeType = + MemRefType::get({getNumDataFieldsFromEncoding(enc)}, indexType); + // memref pointers + Type ptrMemType = MemRefType::get({ShapedType::kDynamic}, ptrType); + // memref indices + Type idxMemType = MemRefType::get({ShapedType::kDynamic}, idxType); + // memref values + Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType); + + foreachFieldInSparseTensor( + enc, + [metaDataType, memSizeType, ptrMemType, idxMemType, valMemType, + callback](unsigned fieldIdx, SparseTensorFieldKind fieldKind, + unsigned dim, DimLevelType dlt) -> bool { + switch (fieldKind) { + case SparseTensorFieldKind::MetaData: + return callback(metaDataType, fieldIdx, fieldKind, dim, dlt); + case SparseTensorFieldKind::MemSizes: + return callback(memSizeType, fieldIdx, fieldKind, dim, dlt); + case SparseTensorFieldKind::PtrMemRef: +
return callback(ptrMemType, fieldIdx, fieldKind, dim, dlt); + case SparseTensorFieldKind::IdxMemRef: + return callback(idxMemType, fieldIdx, fieldKind, dim, dlt); + case SparseTensorFieldKind::ValMemRef: + return callback(valMemType, fieldIdx, fieldKind, dim, dlt); + }; + llvm_unreachable("unrecognized field kind"); + }); +} + +unsigned sparse_tensor::getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) { + unsigned numFields = 0; + foreachFieldInSparseTensor(enc, + [&numFields](unsigned, SparseTensorFieldKind, + unsigned, DimLevelType) -> bool { + numFields++; + return true; + }); + return numFields; +} + +unsigned +sparse_tensor::getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc) { + unsigned numFields = 0; // data fields plus the trailing metadata + foreachFieldInSparseTensor(enc, + [&numFields](unsigned fidx, SparseTensorFieldKind, + unsigned, DimLevelType) -> bool { + if (fidx >= kDataFieldIdx) + numFields++; + return true; + }); + numFields -= 1; // exclude the trailing MetaData field + assert(numFields == getNumFieldsFromEncoding(enc) - kDataFieldIdx - 1); + return numFields; +} diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -16,6 +16,7 @@ //===----------------------------------------------------------------------===// #include "CodegenUtils.h" +#include "SparseTensorBuilder.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -40,38 +41,6 @@ // Helper methods. //===----------------------------------------------------------------------===// -/// Returns the "tuple" value of the adapted tensor.
-static UnrealizedConversionCastOp getTuple(Value tensor) { - return llvm::cast(tensor.getDefiningOp()); -} - -static SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) { - auto tuple = getTuple(tensor); - return SparseTensorDescriptor(tuple.getResultTypes()[0], tuple.getInputs()); -} - -static MutSparseTensorDescriptor -getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl &fields) { - auto tuple = getTuple(tensor); - fields.assign(tuple.getInputs().begin(), tuple.getInputs().end()); - return MutSparseTensorDescriptor(tuple.getResultTypes()[0], fields); -} - -/// Packs the given values as a "tuple" value. -static Value genTuple(OpBuilder &builder, Location loc, Type tp, - ValueRange values) { - return builder.create(loc, TypeRange(tp), values) - .getResult(0); -} - -static Value genTuple(OpBuilder &builder, Location loc, - SparseTensorDescriptor desc) { - return builder - .create(loc, desc.getTensorType(), - desc.getFields()) - .getResult(0); -} - /// Flatten a list of operands that may contain sparse tensors. static void flattenOperands(ValueRange operands, SmallVectorImpl &flattened) { @@ -145,9 +114,9 @@ - // Any other query can consult the dimSizes array at field DimSizesIdx, + // Any other query can consult the dimSizes array in the metadata struct, // accounting for the reordering applied to the sparse storage.
- Value idx = constantIndex(builder, loc, toStoredDim(rtp, dim)); - return builder.create(loc, desc.getDimSizesMemRef(), idx) - .getResult(); + return toType(builder, loc, + desc.getDimSize(builder, loc, toStoredDim(rtp, dim)), + builder.getIndexType()); } // Gets the dimension size at the given stored dimension 'd', either as a @@ -160,8 +129,8 @@ if (!ShapedType::isDynamic(shape[dim])) return constantIndex(builder, loc, shape[dim]); - return genLoad(builder, loc, desc.getDimSizesMemRef(), - constantIndex(builder, loc, d)); + return toType(builder, loc, desc.getDimSize(builder, loc, d), + builder.getIndexType()); } static void createPushback(OpBuilder &builder, Location loc, @@ -176,26 +145,6 @@ desc.setField(fidx, newField); } -/// Maps a sparse tensor type to the appropriate compounded buffers. -static Optional -convertSparseTensorType(Type type, SmallVectorImpl &fields) { - auto enc = getSparseTensorEncoding(type); - if (!enc) - return std::nullopt; - - RankedTensorType rType = type.cast(); - foreachFieldAndTypeInSparseTensor( - rType, - [&fields](Type fieldType, unsigned fieldIdx, - SparseTensorFieldKind /*fieldKind*/, unsigned /*dim*/, - DimLevelType /*dlt*/) -> bool { - assert(fieldIdx == fields.size()); - fields.push_back(fieldType); - return true; - }); - return success(); -} - /// Generates code that allocates a sparse storage scheme for given rank.
static void allocSchemeForRank(OpBuilder &builder, Location loc, MutSparseTensorDescriptor desc, unsigned r0) { @@ -250,28 +199,31 @@ /// static void createAllocFields(OpBuilder &builder, Location loc, Type type, ValueRange dynSizes, bool enableInit, - SmallVectorImpl &fields) { + SmallVectorImpl &fields, + unsigned indexBitwidth) { RankedTensorType rtp = type.cast(); Value heuristic = constantIndex(builder, loc, 16); foreachFieldAndTypeInSparseTensor( - rtp, + rtp, indexBitwidth, [&builder, &fields, loc, heuristic, enableInit](Type fType, unsigned fIdx, SparseTensorFieldKind fKind, unsigned /*dim*/, DimLevelType /*dlt*/) -> bool { assert(fields.size() == fIdx); - auto memRefTp = fType.cast(); Value field; switch (fKind) { - case SparseTensorFieldKind::DimSizes: + case SparseTensorFieldKind::MetaData: + field = SparseTensorMetaData::undef(builder, loc, fType); + break; case SparseTensorFieldKind::MemSizes: - field = builder.create(loc, memRefTp); + field = + builder.create(loc, fType.cast()); break; case SparseTensorFieldKind::PtrMemRef: case SparseTensorFieldKind::IdxMemRef: case SparseTensorFieldKind::ValMemRef: - field = - createAllocation(builder, loc, memRefTp, heuristic, enableInit); + field = createAllocation(builder, loc, fType.cast(), + heuristic, enableInit); break; } assert(field); @@ -305,8 +257,8 @@ for (unsigned r = 0; r < rank; r++) { unsigned ro = toOrigDim(rtp, r); // Fills dim sizes array. - genStore(builder, loc, sizes[ro], desc.getDimSizesMemRef(), - constantIndex(builder, loc, r)); + desc.setDimSize(builder, loc, r, + toType(builder, loc, sizes[ro], desc.getIndexType())); // Pushes a leading zero to pointers memref.
if (isCompressedDim(rtp, r)) @@ -691,7 +643,8 @@ if (!sz) return failure(); - rewriter.replaceOp(op, *sz); + rewriter.replaceOp( + op, toType(rewriter, op.getLoc(), *sz, rewriter.getIndexType())); return success(); } }; @@ -719,9 +672,10 @@ public: using OpConversionPattern::OpConversionPattern; SparseTensorAllocConverter(TypeConverter &typeConverter, MLIRContext *context, - bool enableInit) + bool enableInit, unsigned indexBitwidth) : OpConversionPattern(typeConverter, context), - enableBufferInitialization(enableInit) {} + enableBufferInitialization(enableInit), indexBitwidth(indexBitwidth) {} + LogicalResult matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -736,7 +690,7 @@ Location loc = op.getLoc(); SmallVector fields; createAllocFields(rewriter, loc, resType, adaptor.getOperands(), - enableBufferInitialization, fields); + enableBufferInitialization, fields, indexBitwidth); // Replace operation with resulting memrefs. rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields)); return success(); @@ -744,6 +698,7 @@ private: bool enableBufferInitialization; + unsigned indexBitwidth; // bitwidth of the dimSizes entries in the metadata }; /// Sparse codegen rule for the dealloc operator. @@ -760,8 +715,8 @@ // Replace the sparse tensor deallocation with field deallocations. Location loc = op.getLoc(); - auto tuple = getTuple(adaptor.getTensor()); - for (auto input : tuple.getInputs()) + auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); + for (auto input : desc.getMemRefFields()) // Deallocate every buffer used to store the sparse tensor handler. rewriter.create(loc, input); @@ -1027,26 +982,6 @@ } // namespace -//===----------------------------------------------------------------------===// -// Sparse tensor type conversion into an actual buffer.
-//===----------------------------------------------------------------------===// - -mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() { - addConversion([](Type type) { return type; }); - addConversion(convertSparseTensorType); - - // Required by scf.for 1:N type conversion. - addSourceMaterialization([](OpBuilder &builder, RankedTensorType tp, - ValueRange inputs, - Location loc) -> Optional { - if (!getSparseTensorEncoding(tp)) - // Not a sparse tensor. - return std::nullopt; - // Sparse compiler knows how to cancel out these casts. - return genTuple(builder, loc, tp, inputs); - }); -} - //===----------------------------------------------------------------------===// // Public method for populating conversion rules. //===----------------------------------------------------------------------===// @@ -1054,8 +989,8 @@ /// Populates the given patterns list with conversion rules required for /// the sparsification of linear algebra operations. void mlir::populateSparseTensorCodegenPatterns( - TypeConverter &typeConverter, RewritePatternSet &patterns, - bool enableBufferInitialization) { + SparseTensorTypeToBufferConverter &typeConverter, + RewritePatternSet &patterns, bool enableBufferInitialization) { patterns.add(typeConverter, patterns.getContext()); - patterns.add(typeConverter, patterns.getContext(), - enableBufferInitialization); + + patterns.add( + typeConverter, patterns.getContext(), enableBufferInitialization, + typeConverter.getIndexTypeBitwidth()); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -5,6 +5,7 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// +#include "SparseTensorBuilder.h" #include 
"mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -186,7 +187,8 @@ void runOnOperation() override { auto *ctx = &getContext(); RewritePatternSet patterns(ctx); - SparseTensorTypeToBufferConverter converter; + // TODO: make the index type bitwidth an option. + SparseTensorTypeToBufferConverter converter(64u); ConversionTarget target(*ctx); // Most ops in the sparse dialect must go! target.addIllegalDialect(); @@ -216,9 +218,10 @@ // The following operations and dialects may be introduced by the // codegen rules, and are therefore marked as legal. target.addLegalOp(); - target.addLegalDialect< - arith::ArithDialect, bufferization::BufferizationDialect, - complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>(); + target.addLegalDialect(); target.addLegalOp(); // Populate with rules and apply rewriting rules. populateFunctionOpInterfaceTypeConversionPattern(patterns, diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir --- a/mlir/test/Dialect/SparseTensor/codegen.mlir +++ b/mlir/test/Dialect/SparseTensor/codegen.mlir @@ -47,31 +47,31 @@ }> // CHECK-LABEL: func @sparse_nop( -// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref) +// CHECK-SAME: %[[A0:.*]]: memref<3xindex>, +// CHECK-SAME: %[[A1:.*]]: memref, +// CHECK-SAME: %[[A2:.*]]: memref, +// CHECK-SAME: %[[A3:.*]]: memref, +// CHECK-SAME: %[[A4:.*]]: !llvm.struct<(array<1 x i64>)>) // CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]] : -// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref +// CHECK-SAME: memref<3xindex>, memref, memref, memref, !llvm.struct<(array<1 x i64>)> func.func @sparse_nop(%arg0: tensor) -> tensor { return %arg0 : tensor } // CHECK-LABEL: func @sparse_nop_multi_ret( -// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: 
%[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: memref<1xindex>, -// CHECK-SAME: %[[A6:.*6]]: memref<3xindex>, -// CHECK-SAME: %[[A7:.*7]]: memref, -// CHECK-SAME: %[[A8:.*8]]: memref, -// CHECK-SAME: %[[A9:.*9]]: memref) +// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: !llvm.struct<(array<1 x i64>)>, +// CHECK-SAME: %[[A5:.*5]]: memref<3xindex>, +// CHECK-SAME: %[[A6:.*6]]: memref, +// CHECK-SAME: %[[A7:.*7]]: memref, +// CHECK-SAME: %[[A8:.*8]]: memref, +// CHECK-SAME: %[[A9:.*9]]: !llvm.struct<(array<1 x i64>)>) // CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[A9]] : -// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref, -// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref +// CHECK-SAME: memref<3xindex>, memref, memref, memref, !llvm.struct<(array<1 x i64>)>, +// CHECK-SAME: memref<3xindex>, memref, memref, memref, !llvm.struct<(array<1 x i64>)> func.func @sparse_nop_multi_ret(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { @@ -79,20 +79,20 @@ } // CHECK-LABEL: func @sparse_nop_call( -// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: memref<1xindex>, -// CHECK-SAME: %[[A6:.*6]]: memref<3xindex>, -// CHECK-SAME: %[[A7:.*7]]: memref, -// CHECK-SAME: %[[A8:.*8]]: memref, -// CHECK-SAME: %[[A9:.*9]]: memref) +// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: !llvm.struct<(array<1 x i64>)>, +// CHECK-SAME: 
%[[A5:.*5]]: memref<3xindex>, +// CHECK-SAME: %[[A6:.*6]]: memref, +// CHECK-SAME: %[[A7:.*7]]: memref, +// CHECK-SAME: %[[A8:.*8]]: memref, +// CHECK-SAME: %[[A9:.*9]]: !llvm.struct<(array<1 x i64>)>) // CHECK: %[[T:.*]]:10 = call @sparse_nop_multi_ret(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[A9]]) // CHECK: return %[[T]]#0, %[[T]]#1, %[[T]]#2, %[[T]]#3, %[[T]]#4, %[[T]]#5, %[[T]]#6, %[[T]]#7, %[[T]]#8, %[[T]]#9 : -// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref, -// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref +// CHECK-SAME: memref<3xindex>, memref, memref, memref, !llvm.struct<(array<1 x i64>)>, +// CHECK-SAME: memref<3xindex>, memref, memref, memref, !llvm.struct<(array<1 x i64>)> func.func @sparse_nop_call(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { @@ -103,68 +103,67 @@ } // CHECK-LABEL: func @sparse_nop_cast( -// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref) +// CHECK-SAME: %[[A0:.*]]: memref<3xindex>, +// CHECK-SAME: %[[A1:.*]]: memref, +// CHECK-SAME: %[[A2:.*]]: memref, +// CHECK-SAME: %[[A3:.*]]: memref, +// CHECK-SAME: %[[A4:.*]]: !llvm.struct<(array<1 x i64>)>) // CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]] : -// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref func.func @sparse_nop_cast(%arg0: tensor<64xf32, #SparseVector>) -> tensor { %0 = tensor.cast %arg0 : tensor<64xf32, #SparseVector> to tensor return %0 : tensor } // CHECK-LABEL: func @sparse_nop_cast_3d( -// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref) +// CHECK-SAME: %[[A0:.*]]: memref<1xindex>, +// CHECK-SAME: %[[A1:.*]]: memref, +// CHECK-SAME: %[[A2:.*]]: !llvm.struct<(array<3 x i64>)>) // CHECK: return %[[A0]], %[[A1]], %[[A2]] : 
-// CHECK-SAME: memref<3xindex>, memref<1xindex>, memref +// CHECK-SAME: memref<1xindex>, memref, !llvm.struct<(array<3 x i64>)> func.func @sparse_nop_cast_3d(%arg0: tensor<10x20x30xf32, #Dense3D>) -> tensor { %0 = tensor.cast %arg0 : tensor<10x20x30xf32, #Dense3D> to tensor return %0 : tensor } // CHECK-LABEL: func @sparse_dense_2d( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref) +// CHECK-SAME: %[[A0:.*]]: memref<1xindex>, +// CHECK-SAME: %[[A1:.*]]: memref, +// CHECK-SAME: %[[A2:.*]]: !llvm.struct<(array<2 x i64>)>) { // CHECK: return func.func @sparse_dense_2d(%arg0: tensor) { return } // CHECK-LABEL: func @sparse_row( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref) +// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: !llvm.struct<(array<2 x i64>)>) { // CHECK: return func.func @sparse_row(%arg0: tensor) { return } // CHECK-LABEL: func @sparse_csr( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref) +// CHECK-SAME: %[[A0:.*]]: memref<3xindex>, +// CHECK-SAME: %[[A1:.*]]: memref, +// CHECK-SAME: %[[A2:.*]]: memref, +// CHECK-SAME: %[[A3:.*]]: memref, +// CHECK-SAME: %[[A4:.*]]: !llvm.struct<(array<2 x i64>)>) { // CHECK: return func.func @sparse_csr(%arg0: tensor) { return } // CHECK-LABEL: func @sparse_dcsr( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<5xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: memref, -// 
CHECK-SAME: %[[A6:.*6]]: memref) +// CHECK-SAME: %[[A0:.*0]]: memref<5xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: memref, +// CHECK-SAME: %[[A6:.*6]]: !llvm.struct<(array<2 x i64>)>) // CHECK: return func.func @sparse_dcsr(%arg0: tensor) { return @@ -175,9 +174,9 @@ // fold using the original static dimension sizes. // // CHECK-LABEL: func @sparse_dense_3d( -// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref) +// CHECK-SAME: %[[A0:.*]]: memref<1xindex>, +// CHECK-SAME: %[[A1:.*]]: memref, +// CHECK-SAME: %[[A2:.*]]: !llvm.struct<(array<3 x i64>)>) // CHECK: %[[C:.*]] = arith.constant 20 : index // CHECK: return %[[C]] : index func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index { @@ -192,12 +191,12 @@ // since the latter honors the dimOrdering. // // CHECK-LABEL: func @sparse_dense_3d_dyn( -// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref) -// CHECK: %[[C:.*]] = arith.constant 2 : index -// CHECK: %[[L:.*]] = memref.load %[[A0]][%[[C]]] : memref<3xindex> -// CHECK: return %[[L]] : index +// CHECK-SAME: %[[A0:.*]]: memref<1xindex>, +// CHECK-SAME: %[[A1:.*]]: memref, +// CHECK-SAME: %[[A2:.*]]: !llvm.struct<(array<3 x i64>)>) +// CHECK: %[[A3:.*]] = llvm.extractvalue %[[A2]][0, 2] +// CHECK: %[[A4:.*]] = arith.index_cast %[[A3]] : i64 to index +// CHECK: return %[[A4]] : index func.func @sparse_dense_3d_dyn(%arg0: tensor) -> index { %c = arith.constant 1 : index %0 = tensor.dim %arg0, %c : tensor @@ -205,55 +204,55 @@ } // CHECK-LABEL: func @sparse_pointers_dcsr( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<5xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: 
memref, -// CHECK-SAME: %[[A5:.*5]]: memref, -// CHECK-SAME: %[[A6:.*6]]: memref) -// CHECK: return %[[A4]] : memref +// CHECK-SAME: %[[A0:.*0]]: memref<5xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: memref, +// CHECK-SAME: %[[A6:.*6]]: !llvm.struct<(array<2 x i64>)> +// CHECK: return %[[A3]] : memref func.func @sparse_pointers_dcsr(%arg0: tensor) -> memref { %0 = sparse_tensor.pointers %arg0 { dimension = 1 : index } : tensor to memref return %0 : memref } // CHECK-LABEL: func @sparse_indices_dcsr( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<5xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: memref, -// CHECK-SAME: %[[A6:.*6]]: memref) -// CHECK: return %[[A5]] : memref +// CHECK-SAME: %[[A0:.*0]]: memref<5xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: memref, +// CHECK-SAME: %[[A6:.*6]]: !llvm.struct<(array<2 x i64>)> +// CHECK: return %[[A4]] : memref func.func @sparse_indices_dcsr(%arg0: tensor) -> memref { %0 = sparse_tensor.indices %arg0 { dimension = 1 : index } : tensor to memref return %0 : memref } // CHECK-LABEL: func @sparse_values_dcsr( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<5xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: memref, -// CHECK-SAME: %[[A6:.*6]]: memref) -// CHECK: return %[[A6]] : memref +// CHECK-SAME: %[[A0:.*0]]: memref<5xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: memref, +// 
CHECK-SAME: %[[A5:.*5]]: memref, +// CHECK-SAME: %[[A6:.*6]]: !llvm.struct<(array<2 x i64>)>) +// CHECK: return %[[A5]] : memref func.func @sparse_values_dcsr(%arg0: tensor) -> memref { %0 = sparse_tensor.values %arg0 : tensor to memref return %0 : memref } // CHECK-LABEL: func @sparse_noe( -// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref) +// CHECK-SAME: %[[A0:.*]]: memref<3xindex>, +// CHECK-SAME: %[[A1:.*]]: memref, +// CHECK-SAME: %[[A2:.*]]: memref, +// CHECK-SAME: %[[A3:.*]]: memref, +// CHECK-SAME: %[[A4:.*]]: !llvm.struct<(array<1 x i64>)>) // CHECK: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[NOE:.*]] = memref.load %[[A1]][%[[C2]]] : memref<3xindex> +// CHECK: %[[NOE:.*]] = memref.load %[[A0]][%[[C2]]] : memref<3xindex> // CHECK: return %[[NOE]] : index func.func @sparse_noe(%arg0: tensor<128xf64, #SparseVector>) -> index { %0 = sparse_tensor.number_of_entries %arg0 : tensor<128xf64, #SparseVector> @@ -261,16 +260,15 @@ } // CHECK-LABEL: func @sparse_dealloc_csr( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref) -// CHECK: memref.dealloc %[[A0]] : memref<2xindex> -// CHECK: memref.dealloc %[[A1]] : memref<3xindex> -// CHECK: memref.dealloc %[[A2]] : memref -// CHECK: memref.dealloc %[[A3]] : memref -// CHECK: memref.dealloc %[[A4]] : memref +// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: !llvm.struct<(array<2 x i64>)>) +// CHECK: memref.dealloc %[[A0]] : memref<3xindex> +// CHECK: memref.dealloc %[[A1]] : memref +// CHECK: memref.dealloc %[[A2]] : memref +// CHECK: memref.dealloc %[[A3]] : memref // CHECK: return 
func.func @sparse_dealloc_csr(%arg0: tensor) { bufferization.dealloc_tensor %arg0 : tensor @@ -278,53 +276,47 @@ } // CHECK-LABEL: func @sparse_alloc_csc( -// CHECK-SAME: %[[A:.*]]: index) -> -// CHECK-SAME: memref<2xindex>, memref<3xindex>, memref, memref, memref -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index -// CHECK: %[[T0:.*]] = memref.alloc() : memref<2xindex> -// CHECK: %[[T1:.*]] = memref.alloc() : memref<3xindex> -// CHECK: %[[T2:.*]] = memref.alloc() : memref<16xindex> -// CHECK: %[[T3:.*]] = memref.cast %[[T2]] : memref<16xindex> to memref -// CHECK: %[[T4:.*]] = memref.alloc() : memref<16xindex> -// CHECK: %[[T5:.*]] = memref.cast %[[T4]] : memref<16xindex> to memref -// CHECK: %[[T6:.*]] = memref.alloc() : memref<16xf64> -// CHECK: %[[T7:.*]] = memref.cast %[[T6]] : memref<16xf64> to memref -// CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T1]] : memref<3xindex>) -// CHECK: memref.store %[[A]], %[[T0]][%[[C0]]] : memref<2xindex> -// CHECK: memref.store %[[C10]], %[[T0]][%[[C1]]] : memref<2xindex> -// CHECK: %[[P0:.*]] = sparse_tensor.push_back %[[T1]], %[[T3]] -// CHECK: %[[P1:.*]] = sparse_tensor.push_back %[[T1]], %[[P0]] -// CHECK: return %[[T0]], %[[T1]], %[[P1]], %[[T5]], %[[T7]] : -// CHECK-SAME: memref<2xindex>, memref<3xindex>, memref, memref, memref +// CHECK-SAME: %[[A0:.*]]: index) +// CHECK-DAG: %[[A1:.*]] = arith.constant 10 : i64 +// CHECK-DAG: %[[A2:.*]] = arith.constant 0 : index +// CHECK: %[[A3:.*]] = memref.alloc() : memref<3xindex> +// CHECK: %[[A4:.*]] = memref.alloc() : memref<16xindex> +// CHECK: %[[A5:.*]] = memref.cast %[[A4]] : memref<16xindex> to memref +// CHECK: %[[A6:.*]] = memref.alloc() : memref<16xindex> +// CHECK: %[[A7:.*]] = memref.cast %[[A6]] : memref<16xindex> to memref +// CHECK: %[[A8:.*]] = memref.alloc() : memref<16xf64> +// CHECK: %[[A9:.*]] = memref.cast %[[A8]] : memref<16xf64> to memref +// 
CHECK: %[[A10:.*]] = llvm.mlir.undef : !llvm.struct<(array<2 x i64>)> +// CHECK: linalg.fill ins(%[[A2]] : index) outs(%[[A3]] : memref<3xindex>) +// CHECK: %[[A11:.*]] = arith.index_cast %[[A0]] : index to i64 +// CHECK: %[[A12:.*]] = llvm.insertvalue %[[A11]], %[[A10]][0, 0] +// CHECK: %[[A13:.*]] = llvm.insertvalue %[[A1]], %[[A12]][0, 1] +// CHECK: %[[A14:.*]] = sparse_tensor.push_back %[[A3]], %[[A5]], %[[A2]] {idx = 0 : index} : memref<3xindex>, memref, index +// CHECK: %[[A15:.*]] = sparse_tensor.push_back %[[A3]], %[[A14]], %[[A2]], %[[A0]] {idx = 0 : index} : memref<3xindex>, memref, index, index +// CHECK: return %[[A3]], %[[A15]], %[[A7]], %[[A9]], %[[A13]] : memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)> func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> { %0 = bufferization.alloc_tensor(%arg0) : tensor<10x?xf64, #CSC> %1 = sparse_tensor.load %0 : tensor<10x?xf64, #CSC> return %1 : tensor<10x?xf64, #CSC> } -// CHECK-LABEL: func @sparse_alloc_3d() -> -// CHECK-SAME: memref<3xindex>, memref<1xindex>, memref -// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64 -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index -// CHECK-DAG: %[[C20:.*]] = arith.constant 20 : index -// CHECK-DAG: %[[C30:.*]] = arith.constant 30 : index -// CHECK-DAG: %[[C6000:.*]] = arith.constant 6000 : index -// CHECK: %[[A0:.*]] = memref.alloc() : memref<3xindex> -// CHECK: %[[A1:.*]] = memref.alloc() : memref<1xindex> -// CHECK: %[[AV:.*]] = memref.alloc() : memref<16xf64> -// CHECK: %[[A2:.*]] = memref.cast %[[AV]] : memref<16xf64> to memref -// CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[A1]] : memref<1xindex>) -// CHECK: memref.store %[[C30]], %[[A0]][%[[C0]]] : memref<3xindex> -// CHECK: memref.store %[[C10]], %[[A0]][%[[C1]]] : memref<3xindex> -// CHECK: memref.store 
%[[C20]], %[[A0]][%[[C2]]] : memref<3xindex> -// CHECK: %[[P:.*]] = sparse_tensor.push_back %[[A1]], %[[A2]], %[[F0]], %[[C6000]] -// CHECK: return %[[A0]], %[[A1]], %[[P]] : -// CHECK-SAME: memref<3xindex>, memref<1xindex>, memref +// CHECK-LABEL: func @sparse_alloc_3d() +// CHECK-DAG: %[[A0:.*]] = arith.constant 6000 : index +// CHECK-DAG: %[[A1:.*]] = arith.constant 20 : i64 +// CHECK-DAG: %[[A2:.*]] = arith.constant 10 : i64 +// CHECK-DAG: %[[A3:.*]] = arith.constant 30 : i64 +// CHECK-DAG: %[[A4:.*]] = arith.constant 0.000000e+00 : f64 +// CHECK-DAG: %[[A5:.*]] = arith.constant 0 : index +// CHECK: %[[A6:.*]] = memref.alloc() : memref<1xindex> +// CHECK: %[[A7:.*]] = memref.alloc() : memref<16xf64> +// CHECK: %[[A8:.*]] = memref.cast %[[A7]] : memref<16xf64> to memref +// CHECK: %[[A9:.*]] = llvm.mlir.undef : !llvm.struct<(array<3 x i64>)> +// CHECK: linalg.fill ins(%[[A5]] : index) outs(%[[A6]] : memref<1xindex>) +// CHECK: %[[A10:.*]] = llvm.insertvalue %[[A3]], %[[A9]][0, 0] +// CHECK: %[[A11:.*]] = llvm.insertvalue %[[A2]], %[[A10]][0, 1] +// CHECK: %[[A12:.*]] = llvm.insertvalue %[[A1]], %[[A11]][0, 2] +// CHECK: %[[A13:.*]] = sparse_tensor.push_back %[[A6]], %[[A8]], %[[A4]], %[[A0]] {idx = 0 : index} : memref<1xindex>, memref, f64, index +// CHECK: return %[[A6]], %[[A13]], %[[A12]] : memref<1xindex>, memref, !llvm.struct<(array<3 x i64>)> func.func @sparse_alloc_3d() -> tensor<10x20x30xf64, #Dense3D> { %0 = bufferization.alloc_tensor() : tensor<10x20x30xf64, #Dense3D> %1 = sparse_tensor.load %0 : tensor<10x20x30xf64, #Dense3D> @@ -364,13 +356,9 @@ // CHECK-LABEL: func.func @sparse_expansion3( // CHECK-SAME: %[[D0:.*]]: index, // CHECK-SAME: %{{.*}}: index) -> memref { -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[S0:.*]] = memref.alloc() : memref<2xindex> -// CHECK: memref.store %[[D0]], %[[S0]]{{\[}}%[[C1]]] : memref<2xindex> -// CHECK: %[[D1:.*]] = memref.load %[[S0]]{{\[}}%[[C1]]] : memref<2xindex> -// CHECK: %[[V:.*]] = 
memref.alloc(%[[D1]]) : memref -// CHECK: %[[B:.*]] = memref.alloc(%[[D1]]) : memref -// CHECK: %[[D:.*]] = memref.alloc(%[[D1]]) : memref +// CHECK: %[[V:.*]] = memref.alloc(%[[D0]]) : memref +// CHECK: %[[B:.*]] = memref.alloc(%[[D0]]) : memref +// CHECK: %[[D:.*]] = memref.alloc(%[[D0]]) : memref // CHECK: linalg.fill ins(%{{.*}} : f64) outs(%[[V]] : memref) // CHECK: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref) // CHECK: return %[[D]] : memref @@ -382,45 +370,43 @@ } // CHECK-LABEL: func.func private @_insert_C_100_f64_0_0( -// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, +// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref, // CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: !llvm.struct<(array<1 x i64>)>, // CHECK-SAME: %[[A5:.*5]]: index, // CHECK-SAME: %[[A6:.*6]]: f64) -// CHECK: %[[PV:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]] {idx = 2 : index} : memref<3xindex>, memref, f64 -// CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[PV]] +// CHECK: %[[A22:.*]] = sparse_tensor.push_back %[[A0]], %[[A3]], %[[A6]] {idx = 2 : index} : memref<3xindex>, memref, f64 +// CHECK: return %[[A0]], %[[A1]], %{{.*}}, %[[A22]], %[[A4]] : memref<3xindex>, memref, memref, memref, !llvm.struct<(array<1 x i64>)> // // CHECK-LABEL: func @sparse_compression_1d( -// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, +// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref, // CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: !llvm.struct<(array<1 x i64>)>, // CHECK-SAME: %[[A5:.*5]]: memref, // CHECK-SAME: %[[A6:.*6]]: memref, // CHECK-SAME: %[[A7:.*7]]: memref, // CHECK-SAME: 
%[[A8:.*8]]: index) -// CHECK-DAG: %[[B0:.*]] = arith.constant false -// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64 -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[A9:.*]] = arith.constant false +// CHECK-DAG: %[[A10:.*]] = arith.constant 0.000000e+00 : f64 +// CHECK-DAG: %[[A11:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[A12:.*]] = arith.constant 0 : index // CHECK: sparse_tensor.sort %[[A8]], %[[A7]] : memref -// CHECK: %[[R:.*]]:5 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] -// CHECK-SAME: iter_args(%[[P0:.*]] = %[[A0]], %[[P1:.*]] = %[[A1]], %[[P2:.*]] = %[[A2]], %[[P3:.*]] = %[[A3]], %[[P4:.*]] = %[[A4]]) -> (memref<1xindex>, memref<3xindex>, memref, memref, memref) { -// CHECK: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref -// CHECK: %[[VAL:.*]] = memref.load %[[A5]][%[[INDEX]]] : memref -// CHECK: %[[C:.*]]:5 = func.call @_insert_C_100_f64_0_0(%[[P0]], %[[P1]], %[[P2]], %[[P3]], %[[P4]], %[[INDEX]], %[[VAL]]) -// CHECK: memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref -// CHECK: memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref -// CHECK: scf.yield %[[C]]#0, %[[C]]#1, %[[C]]#2, %[[C]]#3, %[[C]]#4 : memref<1xindex>, memref<3xindex>, memref, memref, memref +// CHECK: %[[A13:.*]]:5 = scf.for %[[A14:.*]] = %[[A12]] to %[[A8]] step %[[A11]] iter_args(%[[A15:.*]] = %[[A0]], %[[A16:.*]] = %[[A1]], %[[A17:.*]] = %[[A2]], %[[A18:.*]] = %[[A3]], %[[A19:.*]] = %[[A4]]) -> (memref<3xindex>, memref, memref, memref, !llvm.struct<(array<1 x i64>)>) { +// CHECK: %[[A20:.*]] = memref.load %[[A7]]{{\[}}%[[A14]]] : memref +// CHECK: %[[A21:.*]] = memref.load %[[A5]]{{\[}}%[[A20]]] : memref +// CHECK: %[[A22:.*]]:5 = func.call @_insert_C_100_f64_0_0(%[[A15]], %[[A16]], %[[A17]], %[[A18]], %[[A19]], %[[A20]], %[[A21]]) : (memref<3xindex>, memref, memref, memref, !llvm.struct<(array<1 x i64>)>, index, f64) -> (memref<3xindex>, memref, memref, memref, 
!llvm.struct<(array<1 x i64>)>) +// CHECK: memref.store %[[A10]], %[[A5]]{{\[}}%[[A20]]] : memref +// CHECK: memref.store %[[A9]], %[[A6]]{{\[}}%[[A20]]] : memref +// CHECK: scf.yield %[[A22]]#0, %[[A22]]#1, %[[A22]]#2, %[[A22]]#3, %[[A22]]#4 : memref<3xindex>, memref, memref, memref, !llvm.struct<(array<1 x i64>)> // CHECK: } // CHECK: memref.dealloc %[[A5]] : memref // CHECK: memref.dealloc %[[A6]] : memref // CHECK: memref.dealloc %[[A7]] : memref -// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4 -// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref +// CHECK: return %[[A23:.*]]#0, %[[A23]]#1, %[[A23]]#2, %[[A23]]#3, %[[A23]]#4 : memref<3xindex>, memref, memref, memref, !llvm.struct<(array<1 x i64>)> func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>, %values: memref, %filled: memref, @@ -433,47 +419,46 @@ } // CHECK-LABEL: func.func private @_insert_D_C_8_8_f64_64_32( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: !llvm.struct<(array<2 x i64>)>, // CHECK-SAME: %[[A5:.*5]]: index, // CHECK-SAME: %[[A6:.*6]]: index, // CHECK-SAME: %[[A7:.*7]]: f64) -// CHECK: %[[PV:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A7]] {idx = 2 : index} : memref<3xindex>, memref, f64 -// CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[PV]] -// +// CHECK: %[[PV:.*]] = sparse_tensor.push_back %[[A0]], %[[A3]], %[[A7]] {idx = 2 : index} : memref<3xindex>, memref, f64 +// CHECK: return %[[A0]], %[[A1]], %{{.*}}, %[[PV]], %[[A4]] +/// // CHECK-LABEL: func @sparse_compression( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: 
memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: !llvm.struct<(array<2 x i64>)>, // CHECK-SAME: %[[A5:.*5]]: memref, // CHECK-SAME: %[[A6:.*6]]: memref, // CHECK-SAME: %[[A7:.*7]]: memref, // CHECK-SAME: %[[A8:.*8]]: index, // CHECK-SAME: %[[A9:.*9]]: index) -// CHECK-DAG: %[[B0:.*]] = arith.constant false -// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64 -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[A10:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[A11:.*]] = arith.constant false +// CHECK-DAG: %[[A12:.*]] = arith.constant 0.000000e+00 : f64 +// CHECK-DAG: %[[A13:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[A14:.*]] = arith.constant 0 : index // CHECK: sparse_tensor.sort %[[A8]], %[[A7]] : memref -// CHECK: %[[R:.*]]:5 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] -// CHECK-SAME: iter_args(%[[P0:.*]] = %[[A0]], %[[P1:.*]] = %[[A1]], %[[P2:.*]] = %[[A2]], %[[P3:.*]] = %[[A3]], %[[P4:.*]] = %[[A4]]) -> (memref<2xindex>, memref<3xindex>, memref, memref, memref) { -// CHECK: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref -// CHECK: %[[VAL:.*]] = memref.load %[[A5]][%[[INDEX]]] : memref -// CHECK: %[[C:.*]]:5 = func.call @_insert_D_C_8_8_f64_64_32(%[[P0]], %[[P1]], %[[P2]], %[[P3]], %[[P4]], %[[A9]], %[[INDEX]], %[[VAL]]) -// CHECK: memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref -// CHECK: memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref -// CHECK: scf.yield %[[C]]#0, %[[C]]#1, %[[C]]#2, %[[C]]#3, %[[C]]#4 : memref<2xindex>, memref<3xindex>, memref, memref, memref +// CHECK: %[[R:.*]]:5 = scf.for %[[A16:.*]] = %[[A14]] to %[[A8]] step %[[A13]] iter_args(%[[A17:.*]] = %[[A0]], %[[A18:.*]] = %[[A1]], %[[A19:.*]] = %[[A2]], %[[A20:.*]] = 
%[[A3]], %[[A21:.*]] = %[[A4]]) -> (memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)>) { +// CHECK: %[[A22:.*]] = memref.load %[[A7]]{{\[}}%[[A16]]] : memref +// CHECK: %[[A23:.*]] = memref.load %[[A5]]{{\[}}%[[A22]]] : memref +// CHECK: %[[A24:.*]]:5 = func.call @_insert_D_C_8_8_f64_64_32(%[[A17]], %[[A18]], %[[A19]], %[[A20]], %[[A21]], %[[A9]], %[[A22]], %[[A23]]) +// CHECK: memref.store %[[A12]], %[[A5]]{{\[}}%[[A22]]] : memref +// CHECK: memref.store %[[A11]], %[[A6]]{{\[}}%[[A22]]] : memref +// CHECK: scf.yield %[[A24]]#0, %[[A24]]#1, %[[A24]]#2, %[[A24]]#3, %[[A24]]#4 : memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)> // CHECK: } // CHECK: memref.dealloc %[[A5]] : memref // CHECK: memref.dealloc %[[A6]] : memref // CHECK: memref.dealloc %[[A7]] : memref -// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4 -// CHECK-SAME: memref<2xindex>, memref<3xindex>, memref, memref, memref +// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4 : memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)> func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>, %values: memref, %filled: memref, @@ -487,47 +472,45 @@ } // CHECK-LABEL: func.func private @_insert_D_C_8_8_f64_0_0( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, +// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref, // CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: !llvm.struct<(array<2 x i64>)>, // CHECK-SAME: %[[A5:.*5]]: index, // CHECK-SAME: %[[A6:.*6]]: index, // CHECK-SAME: %[[A7:.*7]]: f64) -// CHECK: %[[PV:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A7]] {idx = 2 : index} : memref<3xindex>, memref, f64 -// CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[PV]] +// CHECK: %[[PV:.*]] = sparse_tensor.push_back 
%[[A0]], %[[A3]], %[[A7]] {idx = 2 : index} : memref<3xindex>, memref, f64 +// CHECK: return %[[A0]], %[[A1]], %{{.*}}, %[[PV]], %[[A4]] // // CHECK-LABEL: func @sparse_compression_unordered( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, +// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref, // CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: !llvm.struct<(array<2 x i64>)>, // CHECK-SAME: %[[A5:.*5]]: memref, // CHECK-SAME: %[[A6:.*6]]: memref, // CHECK-SAME: %[[A7:.*7]]: memref, // CHECK-SAME: %[[A8:.*8]]: index, // CHECK-SAME: %[[A9:.*9]]: index) -// CHECK-DAG: %[[B0:.*]] = arith.constant false -// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64 -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[A10:.*]] = arith.constant false +// CHECK-DAG: %[[A11:.*]] = arith.constant 0.000000e+00 : f64 +// CHECK-DAG: %[[A12:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[A13:.*]] = arith.constant 1 : index // CHECK-NOT: sparse_tensor.sort -// CHECK: %[[R:.*]]:5 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] -// CHECK-SAME: iter_args(%[[P0:.*]] = %[[A0]], %[[P1:.*]] = %[[A1]], %[[P2:.*]] = %[[A2]], %[[P3:.*]] = %[[A3]], %[[P4:.*]] = %[[A4]]) -> (memref<2xindex>, memref<3xindex>, memref, memref, memref) { -// CHECK: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref -// CHECK: %[[VAL:.*]] = memref.load %[[A5]][%[[INDEX]]] : memref -// CHECK: %[[C:.*]]:5 = func.call @_insert_D_C_8_8_f64_0_0(%[[P0]], %[[P1]], %[[P2]], %[[P3]], %[[P4]], %[[A9]], %[[INDEX]], %[[VAL]]) -// CHECK: memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref -// CHECK: memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref -// CHECK: scf.yield %[[C]]#0, %[[C]]#1, %[[C]]#2, %[[C]]#3, %[[C]]#4 : memref<2xindex>, memref<3xindex>, 
memref, memref, memref +// CHECK: %[[R:.*]]:5 = scf.for %[[A15:.*]] = %[[A12]] to %[[A8]] step %[[A13]] iter_args(%[[A16:.*]] = %[[A0]], %[[A17:.*]] = %[[A1]], %[[A18:.*]] = %[[A2]], %[[A19:.*]] = %[[A3]], %[[A20:.*]] = %[[A4]]) -> (memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)>) { +// CHECK: %[[A21:.*]] = memref.load %[[A7]]{{\[}}%[[A15]]] : memref +// CHECK: %[[A22:.*]] = memref.load %[[A5]]{{\[}}%[[A21]]] : memref +// CHECK: %[[A23:.*]]:5 = func.call @_insert_D_C_8_8_f64_0_0(%[[A16]], %[[A17]], %[[A18]], %[[A19]], %[[A20]], %[[A9]], %[[A21]], %[[A22]]) : (memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)>, index, index, f64) -> (memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)>) +// CHECK: memref.store %[[A11]], %[[A5]]{{\[}}%[[A21]]] : memref +// CHECK: memref.store %[[A10]], %[[A6]]{{\[}}%[[A21]]] : memref +// CHECK: scf.yield %[[A23]]#0, %[[A23]]#1, %[[A23]]#2, %[[A23]]#3, %[[A23]]#4 : memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)> // CHECK: } // CHECK: memref.dealloc %[[A5]] : memref // CHECK: memref.dealloc %[[A6]] : memref // CHECK: memref.dealloc %[[A7]] : memref -// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4 -// CHECK-SAME: memref<2xindex>, memref<3xindex>, memref, memref, memref +// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4 : memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)> func.func @sparse_compression_unordered(%tensor: tensor<8x8xf64, #UCSR>, %values: memref, %filled: memref, @@ -541,26 +524,26 @@ } // CHECK-LABEL: func.func private @_insert_C_128_f64_0_0( -// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, +// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref, // CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: 
%[[A4:.*4]]: !llvm.struct<(array<1 x i64>)>, // CHECK-SAME: %[[A5:.*5]]: index, -// CHECK-SAME: %[[A6:.*6]]: f64) -// CHECK: %[[P:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]] {idx = 2 : index} : memref<3xindex>, memref, f64 -// CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[P]] : -// CHECK: func @sparse_insert( -// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, +// CHECK-SAME: %[[A6:.*6]]: f64) +// CHECK: %[[P:.*]] = sparse_tensor.push_back %[[A0]], %[[A3]], %[[A6]] {idx = 2 : index} +// CHECK: return %[[A0]], %[[A1]], %{{.*}}, %[[P]], %[[A4]] + +// CHECK-LABEL: func @sparse_insert( +// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref, // CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: !llvm.struct<(array<1 x i64>)>, // CHECK-SAME: %[[A5:.*5]]: index, -// CHECK-SAME: %[[A6:.*6]]: f64) +// CHECK-SAME: %[[A6:.*6]]: f64) // CHECK: %[[R:.*]]:5 = call @_insert_C_128_f64_0_0(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]]) // CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4 -// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SV> { %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SV> %1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SV> @@ -568,26 +551,26 @@ } // CHECK-LABEL: func.func private @_insert_C_128_f64_64_32( -// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: 
%[[A4:.*4]]: !llvm.struct<(array<1 x i64>)>, // CHECK-SAME: %[[A5:.*5]]: index, // CHECK-SAME: %[[A6:.*6]]: f64) -// CHECK: %[[P:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]] {idx = 2 : index} : memref<3xindex>, memref, f64 -// CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[P]] : -// CHECK: func @sparse_insert_typed( -// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK: %[[P:.*]] = sparse_tensor.push_back %[[A0]], %[[A3]], %[[A6]] {idx = 2 : index} : memref<3xindex>, memref, f64 +// CHECK: return %[[A0]], %[[A1]], %{{.*}}, %[[P]], %[[A4]] + +// CHECK-LABEL: func @sparse_insert_typed( +// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: !llvm.struct<(array<1 x i64>)>, // CHECK-SAME: %[[A5:.*5]]: index, -// CHECK-SAME: %[[A6:.*6]]: f64) +// CHECK-SAME: %[[A6:.*6]]: f64) // CHECK: %[[R:.*]]:5 = call @_insert_C_128_f64_64_32(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]]) // CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4 -// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SparseVector> { %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector> %1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SparseVector> @@ -595,14 +578,14 @@ } // CHECK-LABEL: func.func @sparse_nop_convert( -// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref) +// CHECK-SAME: %[[A0:.*]]: memref<3xindex>, +// CHECK-SAME: %[[A1:.*]]: memref, +// 
CHECK-SAME: %[[A2:.*]]: memref, +// CHECK-SAME: %[[A3:.*]]: memref, +// CHECK-SAME: %[[A4:.*]]: !llvm.struct<(array<1 x i64>)>) // CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]] : -// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref +// CHECK-SAME: memref<3xindex>, memref, memref, memref, !llvm.struct<(array<1 x i64>)> func.func @sparse_nop_convert(%arg0: tensor<32xf32, #SparseVector>) -> tensor { %0 = sparse_tensor.convert %arg0 : tensor<32xf32, #SparseVector> to tensor return %0 : tensor -} \ No newline at end of file +} diff --git a/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir b/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir --- a/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir +++ b/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir @@ -4,11 +4,10 @@ // CHECK-LABEL: func @sparse_alloc_sparse_vector( // CHECK-SAME: %[[A:.*]]: index) -> -// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref +// CHECK-SAME: memref<3xindex>, memref, memref, memref, !llvm.struct<(array<1 x i64>)> // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64 -// CHECK: %[[T0:.*]] = memref.alloc() : memref<1xindex> // CHECK: %[[T1:.*]] = memref.alloc() : memref<3xindex> // CHECK: %[[T2:.*]] = memref.alloc() : memref<16xindex> // CHECK: %[[T3:.*]] = memref.cast %[[T2]] : memref<16xindex> to memref @@ -19,11 +18,13 @@ // CHECK: %[[T6:.*]] = memref.alloc() : memref<16xf64> // CHECK: %[[T7:.*]] = memref.cast %[[T6]] : memref<16xf64> to memref // CHECK: linalg.fill ins(%[[F0]] : f64) outs(%[[T6]] : memref<16xf64>) +// CHECK: %[[T11:.*]] = llvm.mlir.undef : !llvm.struct<(array<1 x i64>)> // CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T1]] : memref<3xindex>) -// CHECK: memref.store %[[A]], %[[T0]][%[[C0]]] : memref<1xindex> +// CHECK: %[[T12:.*]] = arith.index_cast 
%[[A]] : index to i64 +// CHECK: %[[MD:.*]] = llvm.insertvalue %[[T12]], %[[T11]][0, 0] // CHECK: %[[P0:.*]] = sparse_tensor.push_back %[[T1]], %[[T3]] // CHECK: %[[P1:.*]] = sparse_tensor.push_back %[[T1]], %[[P0]] -// CHECK: return %[[T0]], %[[T1]], %[[P1]], %[[T5]], %[[T7]] : +// CHECK: return %[[T1]], %[[P1]], %[[T5]], %[[T7]], %[[MD]] func.func @sparse_alloc_sparse_vector(%arg0: index) -> tensor { %0 = bufferization.alloc_tensor(%arg0) : tensor %1 = sparse_tensor.load %0 : tensor diff --git a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir --- a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir +++ b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir @@ -1,24 +1,26 @@ // RUN: mlir-opt %s -sparse-tensor-codegen -cse | FileCheck %s #SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }> -// CHECK-LABEL: func @for( -// CHECK-SAME: %[[DIM_SIZE:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[MEM_SIZE:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[POINTER:.*2]]: memref, -// CHECK-SAME: %[[INDICES:.*3]]: memref, -// CHECK-SAME: %[[VALUE:.*4]]: memref, -// CHECK-SAME: %[[LB:.*5]]: index, -// CHECK-SAME: %[[UB:.*6]]: index, -// CHECK-SAME: %[[STEP:.*7]]: index) -// CHECK: %[[OUT:.*]]:5 = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args( -// CHECK-SAME: %[[SIZE:.*]] = %[[DIM_SIZE]], -// CHECK-SAME: %[[MEM:.*]] = %[[MEM_SIZE]], -// CHECK-SAME: %[[PTR:.*]] = %[[POINTER]], -// CHECK-SAME: %[[IDX:.*]] = %[[INDICES]], -// CHECK-SAME: %[[VAL:.*]] = %[[VALUE]]) -// CHECK: scf.yield %[[SIZE]], %[[MEM]], %[[PTR]], %[[IDX]], %[[VAL]] : memref<1xindex>, memref<3xindex>, memref, memref, memref -// CHECK: } -// CHECK: return %[[OUT]]#0, %[[OUT]]#1, %[[OUT]]#2, %[[OUT]]#3, %[[OUT]]#4 : memref<1xindex>, memref<3xindex>, memref, memref, memref + +// CHECK-LABEL: func.func @for( +// CHECK-SAME: %[[VAL_0:.*0]]: memref<3xindex>, +// CHECK-SAME: %[[VAL_1:.*1]]: memref, +// CHECK-SAME: 
%[[VAL_2:.*2]]: memref, +// CHECK-SAME: %[[VAL_3:.*3]]: memref, +// CHECK-SAME: %[[VAL_4:.*4]]: !llvm.struct<(array<1 x i64>)>, +// CHECK-SAME: %[[VAL_5:.*5]]: index, +// CHECK-SAME: %[[VAL_6:.*6]]: index, +// CHECK-SAME: %[[VAL_7:.*7]]: index) -> (memref<3xindex>, memref, memref, memref, !llvm.struct<(array<1 x i64>)>) { +// CHECK: %[[VAL_8:.*]]:5 = scf.for %[[VAL_9:.*]] = %[[VAL_5]] to %[[VAL_6]] step %[[VAL_7]] iter_args( +// CHECK-SAME: %[[VAL_10:.*]] = %[[VAL_0]], +// CHECK-SAME: %[[VAL_11:.*]] = %[[VAL_1]], +// CHECK-SAME: %[[VAL_12:.*]] = %[[VAL_2]], +// CHECK-SAME: %[[VAL_13:.*]] = %[[VAL_3]], +// CHECK-SAME: %[[VAL_14:.*]] = %[[VAL_4]]) +// CHECK: scf.yield %[[VAL_10]], %[[VAL_11]], %[[VAL_12]], %[[VAL_13]], %[[VAL_14]] : +// CHECK: } +// CHECK: return %[[VAL_8]]#0, %[[VAL_8]]#1, %[[VAL_8]]#2, %[[VAL_8]]#3, %[[VAL_8]]#4 : +// CHECK: } func.func @for(%in: tensor<1024xf32, #SparseVector>, %lb: index, %ub: index, %step: index) -> tensor<1024xf32, #SparseVector> { %1 = scf.for %i = %lb to %ub step %step iter_args(%vin = %in) @@ -28,26 +30,25 @@ return %1 : tensor<1024xf32, #SparseVector> } - -// CHECK-LABEL: func @if( -// CHECK-SAME: %[[DIM_SIZE:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[MEM_SIZE:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[POINTER:.*2]]: memref, -// CHECK-SAME: %[[INDICES:.*3]]: memref, -// CHECK-SAME: %[[VALUE:.*4]]: memref, -// CHECK-SAME: %[[DIM_SIZE_1:.*5]]: memref<1xindex>, -// CHECK-SAME: %[[MEM_SIZE_1:.*6]]: memref<3xindex>, -// CHECK-SAME: %[[POINTER_1:.*7]]: memref, -// CHECK-SAME: %[[INDICES_1:.*8]]: memref, -// CHECK-SAME: %[[VALUE_1:.*9]]: memref, -// CHECK-SAME: %[[I1:.*10]]: i1) -> -// CHECK-SAME: (memref<1xindex>, memref<3xindex>, memref, memref, memref) { -// CHECK: %[[SV:.*]]:5 = scf.if %[[I1]] -> (memref<1xindex>, memref<3xindex>, memref, memref, memref) { -// CHECK: scf.yield %[[DIM_SIZE]], %[[MEM_SIZE]], %[[POINTER]], %[[INDICES]], %[[VALUE]] : memref<1xindex>, memref<3xindex>, memref, memref, memref -// CHECK: } else { -// 
CHECK: scf.yield %[[DIM_SIZE_1]], %[[MEM_SIZE_1]], %[[POINTER_1]], %[[INDICES_1]], %[[VALUE_1]] : memref<1xindex>, memref<3xindex>, memref, memref, memref -// CHECK: } -// CHECK: return %[[SV]]#0, %[[SV]]#1, %[[SV]]#2, %[[SV]]#3, %[[SV]]#4 : memref<1xindex>, memref<3xindex>, memref, memref, memref +// CHECK-LABEL: func.func @if( +// CHECK-SAME: %[[VAL_0:.*0]]: memref<3xindex>, +// CHECK-SAME: %[[VAL_1:.*1]]: memref, +// CHECK-SAME: %[[VAL_2:.*2]]: memref, +// CHECK-SAME: %[[VAL_3:.*3]]: memref, +// CHECK-SAME: %[[VAL_4:.*4]]: !llvm.struct<(array<1 x i64>)>, +// CHECK-SAME: %[[VAL_5:.*5]]: memref<3xindex>, +// CHECK-SAME: %[[VAL_6:.*6]]: memref, +// CHECK-SAME: %[[VAL_7:.*7]]: memref, +// CHECK-SAME: %[[VAL_8:.*8]]: memref, +// CHECK-SAME: %[[VAL_9:.*9]]: !llvm.struct<(array<1 x i64>)>, +// CHECK-SAME: %[[VAL_10:.*]]: i1) +// CHECK: %[[VAL_11:.*]]:5 = scf.if %[[VAL_10]] +// CHECK: scf.yield %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_3]], %[[VAL_4]] +// CHECK: } else { +// CHECK: scf.yield %[[VAL_5]], %[[VAL_6]], %[[VAL_7]], %[[VAL_8]], %[[VAL_9]] +// CHECK: } +// CHECK: return %[[VAL_11]]#0, %[[VAL_11]]#1, %[[VAL_11]]#2, %[[VAL_11]]#3, %[[VAL_11]]#4 : +// CHECK-SAME: memref<3xindex>, memref, memref, memref, !llvm.struct<(array<1 x i64>)> func.func @if(%t: tensor<1024xf32, #SparseVector>, %f: tensor<1024xf32, #SparseVector>, %c: i1) -> tensor<1024xf32, #SparseVector> { @@ -59,26 +60,31 @@ return %1 : tensor<1024xf32, #SparseVector> } -// CHECK-LABEL: func @while( -// CHECK-SAME: %[[DIM_SIZE:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[MEM_SIZE:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[POINTER:.*2]]: memref, -// CHECK-SAME: %[[INDICES:.*3]]: memref, -// CHECK-SAME: %[[VALUE:.*4]]: memref, -// CHECK-SAME: %[[I1:.*5]]: i1) -> -// CHECK-SAME: (memref<1xindex>, memref<3xindex>, memref, memref, memref) { -// CHECK: %[[SV:.*]]:5 = scf.while ( -// CHECK-SAME: %[[TMP_DIM:.*]] = %[[DIM_SIZE]], -// CHECK-SAME: %[[TMP_MEM:.*]] = %[[MEM_SIZE]], -// CHECK-SAME: %[[TMP_PTR:.*]] = 
%[[POINTER]], -// CHECK-SAME: %[[TMP_IND:.*]] = %[[INDICES]], -// CHECK-SAME: %[[TMP_VAL:.*]] = %[[VALUE]]) -// CHECK: scf.condition(%[[I1]]) %[[TMP_DIM]], %[[TMP_MEM]], %[[TMP_PTR]], %[[TMP_IND]], %[[TMP_VAL]] : memref<1xindex>, memref<3xindex>, memref, memref, memref -// CHECK: } do { -// CHECK: ^bb0(%[[TMP_DIM]]: memref<1xindex>, %[[TMP_MEM]]: memref<3xindex>, %[[TMP_PTR]]: memref, %[[TMP_IND]]: memref, %[[TMP_VAL]]: memref): -// CHECK: scf.yield %[[TMP_DIM]], %[[TMP_MEM]], %[[TMP_PTR]], %[[TMP_IND]], %[[TMP_VAL]] : memref<1xindex>, memref<3xindex>, memref, memref, memref -// CHECK: } -// CHECK: return %[[SV]]#0, %[[SV]]#1, %[[SV]]#2, %[[SV]]#3, %[[SV]]#4 : memref<1xindex>, memref<3xindex>, memref, memref, memref + +// CHECK-LABEL: func.func @while( +// CHECK-SAME: %[[VAL_0:.*0]]: memref<3xindex>, +// CHECK-SAME: %[[VAL_1:.*1]]: memref, +// CHECK-SAME: %[[VAL_2:.*2]]: memref, +// CHECK-SAME: %[[VAL_3:.*3]]: memref, +// CHECK-SAME: %[[VAL_4:.*4]]: !llvm.struct<(array<1 x i64>)>, +// CHECK-SAME: %[[VAL_5:.*5]]: i1) +// CHECK: %[[VAL_6:.*]]:5 = scf.while ( +// CHECK-SAME: %[[VAL_7:.*]] = %[[VAL_0]], +// CHECK-SAME: %[[VAL_8:.*]] = %[[VAL_1]], +// CHECK-SAME: %[[VAL_9:.*]] = %[[VAL_2]], +// CHECK-SAME: %[[VAL_10:.*]] = %[[VAL_3]], +// CHECK-SAME: %[[VAL_11:.*]] = %[[VAL_4]]) +// CHECK: scf.condition(%[[VAL_5]]) %[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_10]], %[[VAL_11]] +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_12:.*6]]: memref<3xindex>, +// CHECK-SAME: %[[VAL_13:.*7]]: memref, +// CHECK-SAME: %[[VAL_14:.*8]]: memref, +// CHECK-SAME: %[[VAL_15:.*9]]: memref, +// CHECK-SAME: %[[VAL_16:.*10]]: !llvm.struct<(array<1 x i64>)>): +// CHECK: scf.yield %[[VAL_12]], %[[VAL_13]], %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] +// CHECK: } +// CHECK: return %[[VAL_6]]#0, %[[VAL_6]]#1, %[[VAL_6]]#2, %[[VAL_6]]#3, %[[VAL_6]]#4 : +// CHECK-SAME: memref<3xindex>, memref, memref, memref, !llvm.struct<(array<1 x i64>)> func.func @while(%arg0: tensor<1024xf32, #SparseVector>, %c: i1) -> 
tensor<1024xf32, #SparseVector> { %0 = scf.while (%in = %arg0) : (tensor<1024xf32, #SparseVector>) -> tensor<1024xf32, #SparseVector> { scf.condition(%c) %in : tensor<1024xf32, #SparseVector> diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir @@ -13,60 +13,60 @@ // Computes C = A x B with all matrices sparse (SpMSpM) in CSR. // // CHECK-LABEL: func.func private @_insert_D_C_4_4_f64_0_0( -// CHECK-SAME: %[[VAL_0:.*]]: memref<2xindex>, -// CHECK-SAME: %[[VAL_1:.*]]: memref<3xindex>, -// CHECK-SAME: %[[VAL_2:[^ ]+]]: memref, -// CHECK-SAME: %[[VAL_3:.*]]: memref, -// CHECK-SAME: %[[VAL_4:.*]]: memref, -// CHECK-SAME: %[[VAL_5:[^ ]+]]: index, -// CHECK-SAME: %[[VAL_6:.*]]: index, -// CHECK-SAME: %[[VAL_7:.*]]: f64) -> (memref<2xindex>, memref<3xindex>, memref, memref, memref) { -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant false -// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 1 : index +// CHECK-SAME: %[[VAL_0:.*0]]: memref<3xindex>, +// CHECK-SAME: %[[VAL_1:.*1]]: memref, +// CHECK-SAME: %[[VAL_2:.*2]]: memref, +// CHECK-SAME: %[[VAL_3:.*3]]: memref, +// CHECK-SAME: %[[VAL_4:.*4]]: !llvm.struct<(array<2 x i64>)>, +// CHECK-SAME: %[[VAL_5:.*5]]: index, +// CHECK-SAME: %[[VAL_6:.*6]]: index, +// CHECK-SAME: %[[VAL_7:.*7]]: f64) -> (memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)>) { +// CHECK: %[[VAL_8:.*]] = arith.constant false +// CHECK: %[[VAL_9:.*]] = arith.constant 1 : index // CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_5]], %[[VAL_9]] : index -// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_5]]] : memref -// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_10]]] : memref -// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_9]]] : memref<3xindex> +// CHECK: %[[VAL_11:.*]] = memref.load 
%[[VAL_1]]{{\[}}%[[VAL_5]]] : memref +// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_10]]] : memref +// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_9]]] : memref<3xindex> // CHECK: %[[VAL_14:.*]] = arith.subi %[[VAL_12]], %[[VAL_9]] : index // CHECK: %[[VAL_15:.*]] = arith.cmpi ult, %[[VAL_11]], %[[VAL_12]] : index // CHECK: %[[VAL_16:.*]] = scf.if %[[VAL_15]] -> (i1) { -// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_14]]] : memref +// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_14]]] : memref // CHECK: %[[VAL_18:.*]] = arith.cmpi eq, %[[VAL_17]], %[[VAL_6]] : index // CHECK: scf.yield %[[VAL_18]] : i1 // CHECK: } else { -// CHECK: memref.store %[[VAL_13]], %[[VAL_2]]{{\[}}%[[VAL_5]]] : memref +// CHECK: memref.store %[[VAL_13]], %[[VAL_1]]{{\[}}%[[VAL_5]]] : memref // CHECK: scf.yield %[[VAL_8]] : i1 // CHECK: } // CHECK: %[[VAL_19:.*]] = scf.if %[[VAL_20:.*]] -> (memref) { -// CHECK: scf.yield %[[VAL_3]] : memref +// CHECK: scf.yield %[[VAL_2]] : memref // CHECK: } else { // CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_13]], %[[VAL_9]] : index -// CHECK: memref.store %[[VAL_21]], %[[VAL_2]]{{\[}}%[[VAL_10]]] : memref -// CHECK: %[[VAL_22:.*]] = sparse_tensor.push_back %[[VAL_1]], %[[VAL_3]], %[[VAL_6]] {idx = 1 : index} : memref<3xindex>, memref, index +// CHECK: memref.store %[[VAL_21]], %[[VAL_1]]{{\[}}%[[VAL_10]]] : memref +// CHECK: %[[VAL_22:.*]] = sparse_tensor.push_back %[[VAL_0]], %[[VAL_2]], %[[VAL_6]] {idx = 1 : index} : memref<3xindex>, memref, index // CHECK: scf.yield %[[VAL_22]] : memref // CHECK: } -// CHECK: %[[VAL_23:.*]] = sparse_tensor.push_back %[[VAL_1]], %[[VAL_4]], %[[VAL_7]] {idx = 2 : index} : memref<3xindex>, memref, f64 -// CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_24:.*]], %[[VAL_23]] : memref<2xindex>, memref<3xindex>, memref, memref, memref +// CHECK: %[[VAL_23:.*]] = sparse_tensor.push_back %[[VAL_0]], %[[VAL_3]], %[[VAL_7]] {idx = 2 : index} : memref<3xindex>, 
memref, f64 +// CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_24:.*]], %[[VAL_23]], %[[VAL_4]] : memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)> // CHECK: } // CHECK-LABEL: func.func @matmul( -// CHECK-SAME: %[[VAL_0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[VAL_1:.*1]]: memref<3xindex>, +// CHECK-SAME: %[[VAL_0:.*0]]: memref<3xindex>, +// CHECK-SAME: %[[VAL_1:.*1]]: memref, // CHECK-SAME: %[[VAL_2:.*2]]: memref, -// CHECK-SAME: %[[VAL_3:.*3]]: memref, -// CHECK-SAME: %[[VAL_4:.*4]]: memref, -// CHECK-SAME: %[[VAL_5:.*5]]: memref<2xindex>, -// CHECK-SAME: %[[VAL_6:.*6]]: memref<3xindex>, +// CHECK-SAME: %[[VAL_3:.*3]]: memref, +// CHECK-SAME: %[[VAL_4:.*4]]: !llvm.struct<(array<2 x i64>)>, +// CHECK-SAME: %[[VAL_5:.*5]]: memref<3xindex>, +// CHECK-SAME: %[[VAL_6:.*6]]: memref, // CHECK-SAME: %[[VAL_7:.*7]]: memref, -// CHECK-SAME: %[[VAL_8:.*8]]: memref, -// CHECK-SAME: %[[VAL_9:.*9]]: memref) -> (memref<2xindex>, memref<3xindex>, memref, memref, memref) { -// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 0.000000e+00 : f64 -// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[VAL_14:.*]] = arith.constant false -// CHECK-DAG: %[[VAL_15:.*]] = arith.constant true -// CHECK: %[[VAL_16:.*]] = memref.alloc() : memref<2xindex> +// CHECK-SAME: %[[VAL_8:.*8]]: memref, +// CHECK-SAME: %[[VAL_9:.*9]]: !llvm.struct<(array<2 x i64>)>) -> (memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)>) { +// CHECK: %[[VAL_10:.*]] = arith.constant 4 : index +// CHECK: %[[VAL_11:.*]] = arith.constant 4 : i64 +// CHECK: %[[VAL_12:.*]] = arith.constant 0.000000e+00 : f64 +// CHECK: %[[VAL_13:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_14:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_15:.*]] = arith.constant false +// CHECK: %[[VAL_16:.*]] = arith.constant true // CHECK: %[[VAL_17:.*]] = memref.alloc() : 
memref<3xindex> // CHECK: %[[VAL_18:.*]] = memref.alloc() : memref<16xindex> // CHECK: %[[VAL_19:.*]] = memref.cast %[[VAL_18]] : memref<16xindex> to memref @@ -74,72 +74,75 @@ // CHECK: %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<16xindex> to memref // CHECK: %[[VAL_22:.*]] = memref.alloc() : memref<16xf64> // CHECK: %[[VAL_23:.*]] = memref.cast %[[VAL_22]] : memref<16xf64> to memref -// CHECK: linalg.fill ins(%[[VAL_12]] : index) outs(%[[VAL_17]] : memref<3xindex>) -// CHECK: memref.store %[[VAL_10]], %[[VAL_16]]{{\[}}%[[VAL_12]]] : memref<2xindex> -// CHECK: memref.store %[[VAL_10]], %[[VAL_16]]{{\[}}%[[VAL_13]]] : memref<2xindex> -// CHECK: %[[VAL_24:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_19]], %[[VAL_12]] {idx = 0 : index} : memref<3xindex>, memref, index -// CHECK: %[[VAL_25:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_24]], %[[VAL_12]], %[[VAL_10]] {idx = 0 : index} : memref<3xindex>, memref, index, index -// CHECK: %[[VAL_26:.*]] = memref.alloc() : memref<4xf64> -// CHECK: %[[VAL_27:.*]] = memref.alloc() : memref<4xi1> -// CHECK: %[[VAL_28:.*]] = memref.alloc() : memref<4xindex> -// CHECK: %[[VAL_29:.*]] = memref.cast %[[VAL_28]] : memref<4xindex> to memref -// CHECK: linalg.fill ins(%[[VAL_11]] : f64) outs(%[[VAL_26]] : memref<4xf64>) -// CHECK: linalg.fill ins(%[[VAL_14]] : i1) outs(%[[VAL_27]] : memref<4xi1>) -// CHECK: %[[VAL_30:.*]]:5 = scf.for %[[VAL_31:.*]] = %[[VAL_12]] to %[[VAL_10]] step %[[VAL_13]] iter_args(%[[VAL_32:.*]] = %[[VAL_16]], %[[VAL_33:.*]] = %[[VAL_17]], %[[VAL_34:.*]] = %[[VAL_25]], %[[VAL_35:.*]] = %[[VAL_21]], %[[VAL_36:.*]] = %[[VAL_23]]) -> (memref<2xindex>, memref<3xindex>, memref, memref, memref) { -// CHECK: %[[VAL_37:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_31]]] : memref -// CHECK: %[[VAL_38:.*]] = arith.addi %[[VAL_31]], %[[VAL_13]] : index -// CHECK: %[[VAL_39:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_38]]] : memref -// CHECK: %[[VAL_40:.*]] = scf.for %[[VAL_41:.*]] = %[[VAL_37]] to %[[VAL_39]] 
step %[[VAL_13]] iter_args(%[[VAL_42:.*]] = %[[VAL_12]]) -> (index) { -// CHECK: %[[VAL_43:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_41]]] : memref -// CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_41]]] : memref -// CHECK: %[[VAL_45:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_43]]] : memref -// CHECK: %[[VAL_46:.*]] = arith.addi %[[VAL_43]], %[[VAL_13]] : index -// CHECK: %[[VAL_47:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_46]]] : memref -// CHECK: %[[VAL_48:.*]] = scf.for %[[VAL_49:.*]] = %[[VAL_45]] to %[[VAL_47]] step %[[VAL_13]] iter_args(%[[VAL_50:.*]] = %[[VAL_42]]) -> (index) { -// CHECK: %[[VAL_51:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_49]]] : memref -// CHECK: %[[VAL_52:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_51]]] : memref<4xf64> -// CHECK: %[[VAL_53:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_49]]] : memref -// CHECK: %[[VAL_54:.*]] = arith.mulf %[[VAL_44]], %[[VAL_53]] : f64 -// CHECK: %[[VAL_55:.*]] = arith.addf %[[VAL_52]], %[[VAL_54]] : f64 -// CHECK: %[[VAL_56:.*]] = memref.load %[[VAL_27]]{{\[}}%[[VAL_51]]] : memref<4xi1> -// CHECK: %[[VAL_57:.*]] = arith.cmpi eq, %[[VAL_56]], %[[VAL_14]] : i1 -// CHECK: %[[VAL_58:.*]] = scf.if %[[VAL_57]] -> (index) { -// CHECK: memref.store %[[VAL_15]], %[[VAL_27]]{{\[}}%[[VAL_51]]] : memref<4xi1> -// CHECK: memref.store %[[VAL_51]], %[[VAL_28]]{{\[}}%[[VAL_50]]] : memref<4xindex> -// CHECK: %[[VAL_59:.*]] = arith.addi %[[VAL_50]], %[[VAL_13]] : index -// CHECK: scf.yield %[[VAL_59]] : index +// CHECK: %[[VAL_24:.*]] = llvm.mlir.undef : !llvm.struct<(array<2 x i64>)> +// CHECK: linalg.fill ins(%[[VAL_13]] : index) outs(%[[VAL_17]] : memref<3xindex>) +// CHECK: %[[VAL_25:.*]] = llvm.insertvalue %[[VAL_11]], %[[VAL_24]][0, 0] : !llvm.struct<(array<2 x i64>)> +// CHECK: %[[VAL_26:.*]] = llvm.insertvalue %[[VAL_11]], %[[VAL_25]][0, 1] : !llvm.struct<(array<2 x i64>)> +// CHECK: %[[VAL_27:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_19]], %[[VAL_13]] {idx = 0 : index} : memref<3xindex>, memref, 
index +// CHECK: %[[VAL_28:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_27]], %[[VAL_13]], %[[VAL_10]] {idx = 0 : index} : memref<3xindex>, memref, index, index +// CHECK: %[[VAL_29:.*]] = memref.alloc() : memref<4xf64> +// CHECK: %[[VAL_30:.*]] = memref.alloc() : memref<4xi1> +// CHECK: %[[VAL_31:.*]] = memref.alloc() : memref<4xindex> +// CHECK: %[[VAL_32:.*]] = memref.cast %[[VAL_31]] : memref<4xindex> to memref +// CHECK: linalg.fill ins(%[[VAL_12]] : f64) outs(%[[VAL_29]] : memref<4xf64>) +// CHECK: linalg.fill ins(%[[VAL_15]] : i1) outs(%[[VAL_30]] : memref<4xi1>) +// CHECK: %[[VAL_33:.*]]:5 = scf.for %[[VAL_34:.*]] = %[[VAL_13]] to %[[VAL_10]] step %[[VAL_14]] iter_args(%[[VAL_35:.*]] = %[[VAL_17]], %[[VAL_36:.*]] = %[[VAL_28]], %[[VAL_37:.*]] = %[[VAL_21]], %[[VAL_38:.*]] = %[[VAL_23]], %[[VAL_39:.*]] = %[[VAL_26]]) -> (memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)>) { +// CHECK: %[[VAL_40:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_34]]] : memref +// CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_34]], %[[VAL_14]] : index +// CHECK: %[[VAL_42:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_41]]] : memref +// CHECK: %[[VAL_43:.*]] = scf.for %[[VAL_44:.*]] = %[[VAL_40]] to %[[VAL_42]] step %[[VAL_14]] iter_args(%[[VAL_45:.*]] = %[[VAL_13]]) -> (index) { +// CHECK: %[[VAL_46:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_44]]] : memref +// CHECK: %[[VAL_47:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_44]]] : memref +// CHECK: %[[VAL_48:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_46]]] : memref +// CHECK: %[[VAL_49:.*]] = arith.addi %[[VAL_46]], %[[VAL_14]] : index +// CHECK: %[[VAL_50:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_49]]] : memref +// CHECK: %[[VAL_51:.*]] = scf.for %[[VAL_52:.*]] = %[[VAL_48]] to %[[VAL_50]] step %[[VAL_14]] iter_args(%[[VAL_53:.*]] = %[[VAL_45]]) -> (index) { +// CHECK: %[[VAL_54:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_52]]] : memref +// CHECK: %[[VAL_55:.*]] = memref.load %[[VAL_29]]{{\[}}%[[VAL_54]]] : memref<4xf64> 
+// CHECK: %[[VAL_56:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_52]]] : memref +// CHECK: %[[VAL_57:.*]] = arith.mulf %[[VAL_47]], %[[VAL_56]] : f64 +// CHECK: %[[VAL_58:.*]] = arith.addf %[[VAL_55]], %[[VAL_57]] : f64 +// CHECK: %[[VAL_59:.*]] = memref.load %[[VAL_30]]{{\[}}%[[VAL_54]]] : memref<4xi1> +// CHECK: %[[VAL_60:.*]] = arith.cmpi eq, %[[VAL_59]], %[[VAL_15]] : i1 +// CHECK: %[[VAL_61:.*]] = scf.if %[[VAL_60]] -> (index) { +// CHECK: memref.store %[[VAL_16]], %[[VAL_30]]{{\[}}%[[VAL_54]]] : memref<4xi1> +// CHECK: memref.store %[[VAL_54]], %[[VAL_31]]{{\[}}%[[VAL_53]]] : memref<4xindex> +// CHECK: %[[VAL_62:.*]] = arith.addi %[[VAL_53]], %[[VAL_14]] : index +// CHECK: scf.yield %[[VAL_62]] : index // CHECK: } else { -// CHECK: scf.yield %[[VAL_50]] : index +// CHECK: scf.yield %[[VAL_53]] : index // CHECK: } -// CHECK: memref.store %[[VAL_55]], %[[VAL_26]]{{\[}}%[[VAL_51]]] : memref<4xf64> -// CHECK: scf.yield %[[VAL_60:.*]] : index -// CHECK: } {"Emitted from" = "linalg.generic"} -// CHECK: sparse_tensor.sort %[[VAL_62:.*]], %[[VAL_29]] : memref -// CHECK: %[[VAL_63:.*]]:5 = scf.for %[[VAL_64:.*]] = %[[VAL_12]] to %[[VAL_62]] step %[[VAL_13]] iter_args(%[[VAL_65:.*]] = %[[VAL_32]], %[[VAL_66:.*]] = %[[VAL_33]], %[[VAL_67:.*]] = %[[VAL_34]], %[[VAL_68:.*]] = %[[VAL_35]], %[[VAL_69:.*]] = %[[VAL_36]]) -> (memref<2xindex>, memref<3xindex>, memref, memref, memref) { -// CHECK: %[[VAL_70:.*]] = memref.load %[[VAL_28]]{{\[}}%[[VAL_64]]] : memref<4xindex> -// CHECK: %[[VAL_71:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_70]]] : memref<4xf64> -// CHECK: %[[VAL_72:.*]]:5 = func.call @_insert_D_C_4_4_f64_0_0(%[[VAL_65]], %[[VAL_66]], %[[VAL_67]], %[[VAL_68]], %[[VAL_69]], %[[VAL_31]], %[[VAL_70]], %[[VAL_71]]) : (memref<2xindex>, memref<3xindex>, memref, memref, memref, index, index, f64) -> (memref<2xindex>, memref<3xindex>, memref, memref, memref) -// CHECK: memref.store %[[VAL_11]], %[[VAL_26]]{{\[}}%[[VAL_70]]] : memref<4xf64> -// CHECK: memref.store %[[VAL_14]], 
%[[VAL_27]]{{\[}}%[[VAL_70]]] : memref<4xi1> -// CHECK: scf.yield %[[VAL_72]]#0, %[[VAL_72]]#1, %[[VAL_72]]#2, %[[VAL_72]]#3, %[[VAL_72]]#4 : memref<2xindex>, memref<3xindex>, memref, memref, memref +// CHECK: memref.store %[[VAL_58]], %[[VAL_29]]{{\[}}%[[VAL_54]]] : memref<4xf64> +// CHECK: scf.yield %[[VAL_63:.*]] : index +// CHECK: } +// CHECK: scf.yield %[[VAL_64:.*]] : index +// CHECK: } +// CHECK: sparse_tensor.sort %[[VAL_65:.*]], %[[VAL_32]] : memref +// CHECK: %[[VAL_66:.*]]:5 = scf.for %[[VAL_67:.*]] = %[[VAL_13]] to %[[VAL_65]] step %[[VAL_14]] iter_args(%[[VAL_68:.*]] = %[[VAL_35]], %[[VAL_69:.*]] = %[[VAL_36]], %[[VAL_70:.*]] = %[[VAL_37]], %[[VAL_71:.*]] = %[[VAL_38]], %[[VAL_72:.*]] = %[[VAL_39]]) -> (memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)>) { +// CHECK: %[[VAL_73:.*]] = memref.load %[[VAL_31]]{{\[}}%[[VAL_67]]] : memref<4xindex> +// CHECK: %[[VAL_74:.*]] = memref.load %[[VAL_29]]{{\[}}%[[VAL_73]]] : memref<4xf64> +// CHECK: %[[VAL_75:.*]]:5 = func.call @_insert_D_C_4_4_f64_0_0(%[[VAL_68]], %[[VAL_69]], %[[VAL_70]], %[[VAL_71]], %[[VAL_72]], %[[VAL_34]], %[[VAL_73]], %[[VAL_74]]) : (memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)>, index, index, f64) -> (memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)>) +// CHECK: memref.store %[[VAL_12]], %[[VAL_29]]{{\[}}%[[VAL_73]]] : memref<4xf64> +// CHECK: memref.store %[[VAL_15]], %[[VAL_30]]{{\[}}%[[VAL_73]]] : memref<4xi1> +// CHECK: scf.yield %[[VAL_75]]#0, %[[VAL_75]]#1, %[[VAL_75]]#2, %[[VAL_75]]#3, %[[VAL_75]]#4 : memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)> // CHECK: } -// CHECK: scf.yield %[[VAL_73:.*]]#0, %[[VAL_73]]#1, %[[VAL_73]]#2, %[[VAL_73]]#3, %[[VAL_73]]#4 : memref<2xindex>, memref<3xindex>, memref, memref, memref -// CHECK: } {"Emitted from" = "linalg.generic"} -// CHECK: memref.dealloc %[[VAL_26]] : memref<4xf64> -// CHECK: memref.dealloc %[[VAL_27]] : memref<4xi1> -// CHECK: 
memref.dealloc %[[VAL_28]] : memref<4xindex> -// CHECK: %[[VAL_74:.*]] = memref.load %[[VAL_75:.*]]#1{{\[}}%[[VAL_12]]] : memref<3xindex> -// CHECK: %[[VAL_76:.*]] = memref.load %[[VAL_75]]#2{{\[}}%[[VAL_12]]] : memref -// CHECK: %[[VAL_77:.*]] = scf.for %[[VAL_78:.*]] = %[[VAL_13]] to %[[VAL_74]] step %[[VAL_13]] iter_args(%[[VAL_79:.*]] = %[[VAL_76]]) -> (index) { -// CHECK: %[[VAL_80:.*]] = memref.load %[[VAL_75]]#2{{\[}}%[[VAL_78]]] : memref -// CHECK: %[[VAL_81:.*]] = arith.cmpi eq, %[[VAL_80]], %[[VAL_12]] : index -// CHECK: %[[VAL_82:.*]] = arith.select %[[VAL_81]], %[[VAL_79]], %[[VAL_80]] : index -// CHECK: scf.if %[[VAL_81]] { -// CHECK: memref.store %[[VAL_79]], %[[VAL_75]]#2{{\[}}%[[VAL_78]]] : memref +// CHECK: scf.yield %[[VAL_76:.*]]#0, %[[VAL_76]]#1, %[[VAL_76]]#2, %[[VAL_76]]#3, %[[VAL_76]]#4 : memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)> +// CHECK: } +// CHECK: memref.dealloc %[[VAL_29]] : memref<4xf64> +// CHECK: memref.dealloc %[[VAL_30]] : memref<4xi1> +// CHECK: memref.dealloc %[[VAL_31]] : memref<4xindex> +// CHECK: %[[VAL_77:.*]] = memref.load %[[VAL_78:.*]]#0{{\[}}%[[VAL_13]]] : memref<3xindex> +// CHECK: %[[VAL_79:.*]] = memref.load %[[VAL_78]]#1{{\[}}%[[VAL_13]]] : memref +// CHECK: %[[VAL_80:.*]] = scf.for %[[VAL_81:.*]] = %[[VAL_14]] to %[[VAL_77]] step %[[VAL_14]] iter_args(%[[VAL_82:.*]] = %[[VAL_79]]) -> (index) { +// CHECK: %[[VAL_83:.*]] = memref.load %[[VAL_78]]#1{{\[}}%[[VAL_81]]] : memref +// CHECK: %[[VAL_84:.*]] = arith.cmpi eq, %[[VAL_83]], %[[VAL_13]] : index +// CHECK: %[[VAL_85:.*]] = arith.select %[[VAL_84]], %[[VAL_82]], %[[VAL_83]] : index +// CHECK: scf.if %[[VAL_84]] { +// CHECK: memref.store %[[VAL_82]], %[[VAL_78]]#1{{\[}}%[[VAL_81]]] : memref // CHECK: } -// CHECK: scf.yield %[[VAL_82]] : index +// CHECK: scf.yield %[[VAL_85]] : index // CHECK: } -// CHECK: return %[[VAL_75]]#0, %[[VAL_75]]#1, %[[VAL_75]]#2, %[[VAL_75]]#3, %[[VAL_75]]#4 : memref<2xindex>, memref<3xindex>, memref, memref, 
memref +// CHECK: return %[[VAL_78]]#0, %[[VAL_78]]#1, %[[VAL_78]]#2, %[[VAL_78]]#3, %[[VAL_78]]#4 : memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)> // CHECK: } func.func @matmul(%A: tensor<4x8xf64, #CSR>, %B: tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR> {