diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -277,17 +277,18 @@
                      "inBuffer", "value", "$_self.cast().getElementType()">,
        AllTypesMatch<["inBuffer", "outBuffer"]>]>,
-    Arguments<(ins StridedMemRefRankOf<[Index], [1]>:$bufferSizes,
+    Arguments<(ins Index:$curSize,
                    StridedMemRefRankOf<[AnyType], [1]>:$inBuffer,
-                   AnyType:$value, IndexAttr:$idx, Optional:$n,
+                   AnyType:$value, Optional:$n,
                    UnitAttr:$inbounds)>,
-    Results<(outs StridedMemRefRankOf<[AnyType], [1]>:$outBuffer)> {
+    Results<(outs StridedMemRefRankOf<[AnyType], [1]>:$outBuffer,
+                  Index:$newSize)> {
   string summary = "Pushes a value to the back of a given buffer";
   string description = [{
     Push `value` to the end of the given sparse tensor storage buffer
-    `inBuffer` and update the size of the buffer in `bufferSizes[idx]`. The
-    capacity of the buffer is recorded in the memref type of `inBuffer `. If the
-    current buffer is full, then `inBuffer.realloc` is called before pushing the
+    `inBuffer` according to `curSize` and return the new size of the buffer in
+    `newSize`. The capacity of the buffer is recorded in the memref type of `inBuffer`.
+    If the current buffer is full, then `inBuffer.realloc` is called before pushing the
     data to the buffer. This is similar to std::vector push_back.

     The optional input `n` specifies the number of times to repeatedly push
@@ -310,29 +311,28 @@
     Example:

     ```mlir
-    %r = sparse_tensor.push_back %bufferSizes, %buffer, %val
-      {idx = 0 : index} : memref, memref, f64
+    %buf, %newSize = sparse_tensor.push_back %curSize, %buffer, %val
+      : index, memref, f64
     ```

     ```mlir
-    %r = sparse_tensor.push_back inbounds %bufferSizes, %buffer, %val
-      {idx = 0 : index} : memref, memref, f64
+    %buf, %newSize = sparse_tensor.push_back inbounds %curSize, %buffer, %val
+      : index, memref, f64
     ```

     ```mlir
-    %r = sparse_tensor.push_back inbounds %bufferSizes, %buffer, %val, %n
-      {idx = 0 : index} : memref, memref, f64
+    %buf, %newSize = sparse_tensor.push_back inbounds %curSize, %buffer, %val, %n
+      : index, memref, f64
     ```
   }];

-  let assemblyFormat = "(`inbounds` $inbounds^)? $bufferSizes `,` $inBuffer"
+  let assemblyFormat = "(`inbounds` $inbounds^)? $curSize `,` $inBuffer"
                        " `,` $value (`,` $n^ )? attr-dict `:`"
-                       " type($bufferSizes) `,` type($inBuffer) `,`"
-                       " type($value) (`,` type($n)^ )?";
+                       " type($curSize) `,` type($inBuffer) `,`"
+                       " type($value) (`,` type($n)^ )?";

   let builders = [
-    // Build an op without input `n`.
-    OpBuilder<(ins "Type":$outBuffer, "Value":$bufferSizes, "Value":$inBuffer,
-                   "Value":$value, "APInt":$idx)>
+    // Build an op (reusing types from curSize and inBuffer) without input `n`.
+    OpBuilder<(ins "Value":$curSize, "Value":$inBuffer, "Value":$value)>
   ];

   let hasVerifier = 1;
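With the new form, the running size is threaded through SSA values rather than stored in a `bufferSizes` memref, so consecutive pushes chain `newSize` into the next `curSize`. A minimal usage sketch (buffer and value names are illustrative, not taken from the patch):

```mlir
// Push a leading zero onto a pointer buffer, then repeat a zero %n times.
// Each op returns the (possibly reallocated) buffer and the updated size.
%buf1, %size1 = sparse_tensor.push_back %size0, %buf0, %c0
    : index, memref<?xindex>, index
%buf2, %size2 = sparse_tensor.push_back %size1, %buf1, %c0, %n
    : index, memref<?xindex>, index, index
```

This is the same chaining pattern that the updated `sparse_alloc_csc` test below checks for.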
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -581,10 +581,8 @@
 }

 void PushBackOp::build(OpBuilder &builder, OperationState &result,
-                       Type outBuffer, Value bufferSizes, Value inBuffer,
-                       Value value, APInt idx) {
-  build(builder, result, outBuffer, bufferSizes, inBuffer, value,
-        std::move(idx), Value());
+                       Value curSize, Value inBuffer, Value value) {
+  build(builder, result, curSize, inBuffer, value, Value());
 }

 LogicalResult PushBackOp::verify() {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -331,7 +331,7 @@
   Location loc = func.getLoc();
   ValueRange args = entryBlock->getArguments();
   Value p = args[hiIdx];
-  SmallVector types(2, p.getType()); // only two
+  SmallVector types(2, p.getType()); // only two
   scf::WhileOp whileOp = builder.create<scf::WhileOp>(
       loc, types, SmallVector{args[loIdx], args[hiIdx]});

@@ -490,7 +490,7 @@
   Value i = lo;
   Value j = builder.create(loc, hi, c1);
-  SmallVector operands{i, j, p}; // exactly three
+  SmallVector operands{i, j, p}; // exactly three
   SmallVector types{i.getType(), j.getType(), p.getType()};
   scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);

@@ -770,9 +770,7 @@
     Value c0 = constantIndex(rewriter, loc, 0);
     Value buffer = op.getInBuffer();
     Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0);
-    Value idx = constantIndex(rewriter, loc, op.getIdx().getZExtValue());
-    Value bufferSizes = op.getBufferSizes();
-    Value size = rewriter.create<memref::LoadOp>(loc, bufferSizes, idx);
+    Value size = op.getCurSize();
     Value value = op.getValue();

     Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1);
@@ -853,8 +851,7 @@
     }

     // Update the buffer size.
-    rewriter.create<memref::StoreOp>(loc, newSize, bufferSizes, idx);
-    rewriter.replaceOp(op, buffer);
+    rewriter.replaceOp(op, {buffer, newSize});
     return success();
   }
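For reference, the rewriter now expands a single push into a capacity check plus a conditional `memref.realloc`, consuming and producing the size as plain `index` values. A rough sketch of the emitted IR for the non-`inbounds`, single-value case, assuming an `f64` buffer (a paraphrase of the FileCheck expectations in `buffer_rewriting.mlir` below, not literal pass output):

```mlir
%cap     = memref.dim %buffer, %c0 : memref<?xf64>
%newSize = arith.addi %curSize, %c1 : index
%full    = arith.cmpi ugt, %newSize, %cap : index
%buf = scf.if %full -> (memref<?xf64>) {
  %newCap = arith.muli %cap, %c2 : index   // double the capacity
  %big = memref.realloc %buffer(%newCap) : memref<?xf64> to memref<?xf64>
  scf.yield %big : memref<?xf64>
} else {
  scf.yield %buffer : memref<?xf64>
}
memref.store %value, %buf[%curSize] : memref<?xf64>
// %buf and %newSize replace the two push_back results.
```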
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorBuilder.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorBuilder.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorBuilder.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorBuilder.h
@@ -41,12 +41,16 @@
     assert(value); // value inherited from StructBuilder must have been set.
   }

-  static SparseTensorMetaData undef(OpBuilder &builder, Location loc,
-                                    Type metaType);
+  // Undef values for the dimension sizes, all-zero values for the memory sizes.
+  static SparseTensorMetaData getInitValue(OpBuilder &builder, Location loc,
+                                           Type metaType);
   static Type getMetaDataIndexType(Value data);

   Value dimSize(OpBuilder &builder, Location loc, unsigned dim);
   void setDimSize(OpBuilder &builder, Location loc, unsigned dim, Value size);
+
+  Value memSize(OpBuilder &builder, Location loc, unsigned pos);
+  void setMemSize(OpBuilder &builder, Location loc, unsigned pos, Value size);
 };

 //===----------------------------------------------------------------------===//
@@ -59,11 +63,6 @@
 // size is the capacity and the used size resides in the memSizes array.
 //
 // struct {
-//   ; sparse tensor metadata
-//   struct {
-//     array dimSizes   ; size in each dimension
-//   }
-//   memref memSizes    ; sizes of ptrs/inds/values
 //   ; per-dimension d:
 //   ;  if dense:
 //        <nothing>
@@ -73,6 +72,12 @@
 //   ;  if singleton:
 //        memref indices-d   ; indices for singleton dim d
 //   memref values           ; values
+//
+//   ; sparse tensor metadata
+//   struct {
+//     array dimSizes   ; sizes for each dimension
+//     array memSizes   ; sizes for each data memref
+//   }
 // };
 //
 //===----------------------------------------------------------------------===//
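To make the new layout concrete: for a CSR-like tensor with two dimensions and three data memrefs (pointers, indices, values), the trailing metadata field becomes an `!llvm.struct<(array<2 x i64>, array<3 x i64>)>`, with `dimSizes` in slot 0 and `memSizes` in slot 1 (assuming a 64-bit index bitwidth; the shape is illustrative). Reads and writes go through `llvm.extractvalue`/`llvm.insertvalue` plus an `arith.index_cast`, which is what the `toIndex`/`toInteger` helpers below emit:

```mlir
// Read memSizes[2] (the values-buffer size) and convert it to index.
%sz  = llvm.extractvalue %md[1, 2] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
%noe = arith.index_cast %sz : i64 to index

// Write dimSizes[0] from an index value.
%d0  = arith.index_cast %dim0 : index to i64
%md1 = llvm.insertvalue %d0, %md[0, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
```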
@@ -84,8 +89,7 @@
   ValMemRef
 };

-constexpr uint64_t kMemSizesIdx = 0;
-constexpr uint64_t kDataFieldIdx = kMemSizesIdx + 1;
+constexpr uint64_t kDataFieldStartingIdx = 0;

 /// For each field that will be allocated for the given sparse tensor encoding,
 /// calls the callback with the corresponding field index, field kind, dimension
@@ -119,10 +123,12 @@

 /// Get the index of the field in memSizes (only valid for data fields).
 inline unsigned getFieldMemSizesIndex(unsigned fid) {
-  assert(fid >= kDataFieldIdx);
-  return fid - kDataFieldIdx;
+  assert(fid >= kDataFieldStartingIdx);
+  return fid - kDataFieldStartingIdx;
 }

+Value createIndexCast(OpBuilder &builder, Location loc, Value value, Type to);
+
 /// A helper class around an array of values that correspond to a sparse
 /// tensor; provides a set of meaningful APIs to query and update a particular
 /// field in a consistent way.
@@ -142,6 +148,17 @@
   using ValueArrayRef = typename std::conditional &, ValueRange>::type;

+  /// Adds index conversions where needed.
+  Value toIndex(OpBuilder &builder, Location loc, Value value) const {
+    assert(value.getType() == getMetaDataIndexType());
+    return createIndexCast(builder, loc, value, builder.getIndexType());
+  }
+
+  Value toInteger(OpBuilder &builder, Location loc, Value value) const {
+    assert(value.getType() == builder.getIndexType());
+    return createIndexCast(builder, loc, value, getMetaDataIndexType());
+  }
+
 public:
   SparseTensorDescriptorImpl(Type tp, ValueArrayRef fields)
       : indexType(SparseTensorMetaData::getMetaDataIndexType(fields.back())),
@@ -159,7 +176,7 @@
   // SparseTensorDescriptor.
   template >
   /*implicit*/ SparseTensorDescriptorImpl(std::enable_if_t &mDesc)
-      : indexType(mDesc.getIndexType()), rType(mDesc.getTensorType()),
+      : indexType(mDesc.getMetaDataIndexType()), rType(mDesc.getTensorType()),
         fields(mDesc.getFields()) {}

   ///
@@ -177,15 +194,15 @@
   unsigned getValMemRefIndex() const { return fields.size() - 2; }

   unsigned getPtrMemSizesIndex(unsigned dim) const {
-    return getPtrMemRefIndex(dim) - kDataFieldIdx;
+    return getPtrMemRefIndex(dim) - kDataFieldStartingIdx;
   }

   unsigned getIdxMemSizesIndex(unsigned dim) const {
-    return getIdxMemRefIndex(dim) - kDataFieldIdx;
+    return getIdxMemRefIndex(dim) - kDataFieldStartingIdx;
   }

   unsigned getValMemSizesIndex() const {
-    return getValMemRefIndex() - kDataFieldIdx;
+    return getValMemRefIndex() - kDataFieldStartingIdx;
   }

   unsigned getNumFields() const { return fields.size(); }
@@ -196,10 +213,26 @@

   Value getDimSize(OpBuilder &builder, Location loc, unsigned dim) const {
     SparseTensorMetaData md(fields.back());
-    return md.dimSize(builder, loc, dim);
+    return toIndex(builder, loc, md.dimSize(builder, loc, dim));
+  }
+
+  Value getPtrMemRefSize(OpBuilder &builder, Location loc, unsigned dim) const {
+    return getMemSize(builder, loc, getPtrMemSizesIndex(dim));
+  }
+
+  Value getIdxMemRefSize(OpBuilder &builder, Location loc, unsigned dim) const {
+    return getMemSize(builder, loc, getIdxMemSizesIndex(dim));
   }

-  Value getMemSizesMemRef() const { return fields[kMemSizesIdx]; }
+  Value getValMemRefSize(OpBuilder &builder, Location loc) const {
+    return getMemSize(builder, loc, getValMemSizesIndex());
+  }
+
+  Value getMemSize(OpBuilder &builder, Location loc, unsigned pos) const {
+    SparseTensorMetaData md(fields.back());
+    return toIndex(builder, loc, md.memSize(builder, loc, pos));
+  }

   Value getPtrMemRef(unsigned ptrDim) const {
     return fields[getPtrMemRefIndex(ptrDim)];
@@ -230,9 +263,17 @@
   template
   void setDimSize(OpBuilder &builder, Location loc, unsigned dim,
                   std::enable_if_t v) {
-    assert(v.getType() == getIndexType());
     SparseTensorMetaData md(fields.back());
-    md.setDimSize(builder, loc, dim, v);
+    md.setDimSize(builder, loc, dim, toInteger(builder, loc, v));
+    fields.back() = md;
+  }
+
+  template
+  void setMemSize(OpBuilder &builder, Location loc, unsigned pos,
+                  std::enable_if_t v) {
+    SparseTensorMetaData md(fields.back());
+    md.setMemSize(builder, loc, pos, toInteger(builder, loc, v));
+    // Update the metadata SSA value.
     fields.back() = md;
   }

   ///
@@ -244,7 +285,7 @@
   RankedTensorType getTensorType() const { return rType; }
   ValueArrayRef getFields() const { return fields; }

-  Type getIndexType() const { return indexType; }
+  Type getMetaDataIndexType() const { return indexType; }
   Type getElementType(unsigned fidx) const {
     return fields[fidx].getType().template cast().getElementType();
   }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorBuilder.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorBuilder.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorBuilder.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorBuilder.cpp
@@ -5,8 +5,8 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
-
 #include "SparseTensorBuilder.h"
+#include "CodegenUtils.h"

 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
@@ -15,14 +15,21 @@
 using namespace mlir;
 using namespace sparse_tensor;

-static SmallVector
+static SmallVector
 getSparseTensorMetaDataFields(RankedTensorType rtp, unsigned indexBitwidth) {
   MLIRContext *ctx = rtp.getContext();
-  SmallVector result;
+  auto enc = getSparseTensorEncoding(rtp);
+  assert(enc);
+
+  SmallVector result;
+  auto indexType = IntegerType::get(ctx, indexBitwidth);

-  auto dimSizes = LLVM::LLVMArrayType::get(
-      ctx, IntegerType::get(ctx, indexBitwidth), rtp.getRank());
+  auto dimSizes = LLVM::LLVMArrayType::get(ctx, indexType, rtp.getRank());
+
+  auto memSizes = LLVM::LLVMArrayType::get(ctx, indexType,
+                                           getNumDataFieldsFromEncoding(enc));
   result.push_back(dimSizes);
+  result.push_back(memSizes);
   return result;
 }

@@ -32,6 +39,13 @@
       rtp.getContext(), getSparseTensorMetaDataFields(rtp, indexBitwidth));
 }

+Value sparse_tensor::createIndexCast(OpBuilder &builder, Location loc,
+                                     Value value, Type to) {
+  if (value.getType() != to)
+    return builder.create<arith::IndexCastOp>(loc, to, value);
+  return value;
+}
+
 Optional
 SparseTensorTypeToBufferConverter::convertSparseTensorType(
     RankedTensorType rtp, SmallVectorImpl &fields) {
@@ -73,11 +87,24 @@

 // MetaData
 constexpr uint64_t kDimSizePosInMetaData = 0;
+constexpr uint64_t kMemSizePosInMetaData = 1;

-SparseTensorMetaData SparseTensorMetaData::undef(OpBuilder &builder,
-                                                 Location loc, Type metaType) {
+SparseTensorMetaData SparseTensorMetaData::getInitValue(OpBuilder &builder,
+                                                        Location loc,
+                                                        Type metaType) {
   Value metaData = builder.create<LLVM::UndefOp>(loc, metaType);
-  return SparseTensorMetaData(metaData);
+  SparseTensorMetaData md(metaData);
+  auto memSizeArrayType = metaType.cast<LLVM::LLVMStructType>()
+                              .getBody()[kMemSizePosInMetaData]
+                              .cast<LLVM::LLVMArrayType>();
+
+  // Fill the memSizes array with zeros.
+  for (int i = 0, e = memSizeArrayType.getNumElements(); i < e; i++) {
+    md.setMemSize(
+        builder, loc, i,
+        constantZero(builder, loc, memSizeArrayType.getElementType()));
+  }
+  return md;
 }
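On, say, a CSR metadata struct, the effect of `getInitValue` is a chain of zero insertions over slot 1, leaving the dimension sizes undef until `setDimSize` fills them. A sketch matching the pattern in the updated `sparse_alloc_csc` CHECK lines (the struct shape is illustrative):

```mlir
%z  = arith.constant 0 : i64
%u  = llvm.mlir.undef : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
%m0 = llvm.insertvalue %z, %u[1, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
%m1 = llvm.insertvalue %z, %m0[1, 1] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
%m2 = llvm.insertvalue %z, %m1[1, 2] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
```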
 Type SparseTensorMetaData::getMetaDataIndexType(Value data) {
@@ -88,20 +115,34 @@
       .getElementType();
 }

-/// Builds IR inserting the pos-th size into the descriptor
+/// Builds IR inserting the pos-th dimension size into the descriptor.
 void SparseTensorMetaData::setDimSize(OpBuilder &builder, Location loc,
                                       unsigned dim, Value size) {
   value = builder.create<LLVM::InsertValueOp>(
       loc, value, size, ArrayRef({kDimSizePosInMetaData, dim}));
 }

-/// Builds IR inserting the pos-th size into the descriptor
+/// Builds IR extracting the pos-th dimension size from the descriptor.
 Value SparseTensorMetaData::dimSize(OpBuilder &builder, Location loc,
                                     unsigned dim) {
   return builder.create<LLVM::ExtractValueOp>(
       loc, value, ArrayRef({kDimSizePosInMetaData, dim}));
 }

+/// Builds IR extracting the pos-th memory size from the descriptor.
+Value SparseTensorMetaData::memSize(OpBuilder &builder, Location loc,
+                                    unsigned pos) {
+  return builder.create<LLVM::ExtractValueOp>(
+      loc, value, ArrayRef({kMemSizePosInMetaData, pos}));
+}
+
+/// Builds IR inserting the pos-th memory size into the descriptor.
+void SparseTensorMetaData::setMemSize(OpBuilder &builder, Location loc,
+                                      unsigned pos, Value size) {
+  value = builder.create<LLVM::InsertValueOp>(
+      loc, value, size, ArrayRef({kMemSizePosInMetaData, pos}));
+}
+
 void sparse_tensor::foreachFieldInSparseTensor(
     const SparseTensorEncodingAttr enc,
     llvm::function_ref
                              bool {
-        if (fidx >= kDataFieldIdx)
+        if (fidx >= kDataFieldStartingIdx)
           numFields++;
         return true;
       });
   numFields -= 1; // the last field is MetaData field
-  assert(numFields == getNumFieldsFromEncoding(enc) - kDataFieldIdx - 1);
+  assert(numFields ==
+         getNumFieldsFromEncoding(enc) - kDataFieldStartingIdx - 1);
   return numFields;
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -114,9 +114,7 @@

   // Any other query can consult the dimSizes array at field DimSizesIdx,
   // accounting for the reordering applied to the sparse storage.
-  return toType(builder, loc,
-                desc.getDimSize(builder, loc, toStoredDim(rtp, dim)),
-                builder.getIndexType());
+  return desc.getDimSize(builder, loc, toStoredDim(rtp, dim));
 }

 // Gets the dimension size at the given stored dimension 'd', either as a
@@ -129,8 +127,7 @@
   if (!ShapedType::isDynamic(shape[dim]))
     return constantIndex(builder, loc, shape[dim]);

-  return toType(builder, loc, desc.getDimSize(builder, loc, d),
-                builder.getIndexType());
+  return desc.getDimSize(builder, loc, d);
 }

 static void createPushback(OpBuilder &builder, Location loc,
@@ -138,11 +135,14 @@
                            Value value, Value repeat = Value()) {
   Type etp = desc.getElementType(fidx);
   Value field = desc.getField(fidx);
-  Value newField = builder.create<PushBackOp>(
-      loc, field.getType(), desc.getMemSizesMemRef(), field,
-      toType(builder, loc, value, etp), APInt(64, getFieldMemSizesIndex(fidx)),
-      repeat);
-  desc.setField(fidx, newField);
+  unsigned memSizePos = getFieldMemSizesIndex(fidx);
+
+  auto pushBackOp = builder.create<PushBackOp>(
+      loc, desc.getMemSize(builder, loc, memSizePos), field,
+      toType(builder, loc, value, etp), repeat);
+
+  desc.setField(fidx, pushBackOp.getOutBuffer());
+  desc.setMemSize(builder, loc, memSizePos, pushBackOp.getNewSize());
 }
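In the lowered IR this pairs each `push_back` with an `llvm.insertvalue` that stores the returned size back into the metadata struct, e.g. for a pointer buffer (a sketch in the spirit of the updated `sparse_alloc_csc` CHECK lines; names and the struct shape are illustrative):

```mlir
%buf, %newSize = sparse_tensor.push_back %curSize, %ptrs, %ptrZero
    : index, memref<?xindex>, index
%sz  = arith.index_cast %newSize : index to i64
%md1 = llvm.insertvalue %sz, %md0[1, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
```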

 /// Generates code that allocates a sparse storage scheme for given rank.
@@ -214,7 +214,7 @@
     Value field;
     switch (fKind) {
     case SparseTensorFieldKind::MetaData:
-      field = SparseTensorMetaData::undef(builder, loc, fType);
+      field = SparseTensorMetaData::getInitValue(builder, loc, fType);
       break;
     case SparseTensorFieldKind::MemSizes:
       field =
@@ -249,17 +249,16 @@
   // to all zeros, sets the dimSizes to known values and gives all pointer
   // fields an initial zero entry, so that it is easier to maintain the
   // "linear + 1" length property.
-  builder.create<linalg::FillOp>(
-      loc, constantZero(builder, loc, builder.getIndexType()),
-      desc.getMemSizesMemRef()); // zero memSizes
   Value ptrZero =
       constantZero(builder, loc, getSparseTensorEncoding(rtp).getPointerType());
   for (unsigned r = 0; r < rank; r++) {
     unsigned ro = toOrigDim(rtp, r);
     // Fills dim sizes array.
-    desc.setDimSize(builder, loc, r,
-                    toType(builder, loc, sizes[ro], desc.getIndexType()));
+    desc.setDimSize(builder, loc, r, sizes[ro]);

     // Pushes a leading zero to pointers memref.
     if (isCompressedDim(rtp, r))
@@ -303,8 +302,8 @@
     Value pp1 = builder.create<arith::AddIOp>(loc, pos, one);
     Value plo = genLoad(builder, loc, desc.getField(ptrIndex), pos);
     Value phi = genLoad(builder, loc, desc.getField(ptrIndex), pp1);
-    Value psz = constantIndex(builder, loc, getFieldMemSizesIndex(idxIndex));
-    Value msz = genLoad(builder, loc, desc.getMemSizesMemRef(), psz);
+    Value msz = desc.getMemSize(builder, loc, getFieldMemSizesIndex(idxIndex));
+
     Value phim1 = builder.create<arith::SubIOp>(
         loc, toType(builder, loc, phi, indexType), one);
     // Conditional expression.
@@ -517,8 +516,7 @@
     if (d > 0) {
       Type ptrType = getSparseTensorEncoding(rtp).getPointerType();
       Value ptrMemRef = desc.getPtrMemRef(d);
-      Value mz = constantIndex(builder, loc, desc.getPtrMemSizesIndex(d));
-      Value hi = genLoad(builder, loc, desc.getMemSizesMemRef(), mz);
+      Value hi = desc.getPtrMemRefSize(builder, loc, d);
       Value zero = constantIndex(builder, loc, 0);
       Value one = constantIndex(builder, loc, 1);
       // Vector of only one, but needed by createFor's prototype.
@@ -973,10 +971,9 @@
                   ConversionPatternRewriter &rewriter) const override {
     // Query memSizes for the actually stored values size.
     auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
-    Value field =
-        constantIndex(rewriter, op.getLoc(), desc.getValMemSizesIndex());
-    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, desc.getMemSizesMemRef(),
-                                                field);
+    rewriter.replaceOp(op, toType(rewriter, op.getLoc(),
+                                  desc.getValMemRefSize(rewriter, op.getLoc()),
+                                  rewriter.getIndexType()));
     return success();
   }
 };
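After this change, `sparse_tensor.number_of_entries` lowers to a struct read instead of a memref load; for a sparse vector the result is simply memSizes at the values position, cast to `index` (cf. the `sparse_noe` CHECK lines further down):

```mlir
%v   = llvm.extractvalue %md[1, 2] : !llvm.struct<(array<1 x i64>, array<3 x i64>)>
%noe = arith.index_cast %v : i64 to index
```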
diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
--- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
@@ -1,15 +1,14 @@
 // RUN: mlir-opt %s -split-input-file --sparse-buffer-rewrite --canonicalize --cse | FileCheck %s

 // CHECK-LABEL: func @sparse_push_back(
-// CHECK-SAME: %[[A:.*]]: memref,
+// CHECK-SAME: %[[A:.*]]: index,
 // CHECK-SAME: %[[B:.*]]: memref,
-// CHECK-SAME: %[[C:.*]]: f64) -> memref {
+// CHECK-SAME: %[[C:.*]]: f64) -> (memref, index) {
 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[P1:.*]] = memref.dim %[[B]], %[[C0]]
-// CHECK: %[[S1:.*]] = memref.load %[[A]]{{\[}}%[[C2]]]
-// CHECK: %[[S2:.*]] = arith.addi %[[S1]], %[[C1]] : index
+// CHECK: %[[S2:.*]] = arith.addi %[[A]], %[[C1]] : index
 // CHECK: %[[T:.*]] = arith.cmpi ugt, %[[S2]], %[[P1]]
 // CHECK: %[[M:.*]] = scf.if %[[T]] -> (memref) {
 // CHECK: %[[P2:.*]] = arith.muli %[[P1]], %[[C2]]
@@ -18,25 +17,23 @@
 // CHECK: } else {
 // CHECK: scf.yield %[[B]] : memref
 // CHECK: }
-// CHECK: memref.store %[[C]], %[[M]]{{\[}}%[[S1]]]
-// CHECK: memref.store %[[S2]], %[[A]]{{\[}}%[[C2]]]
-// CHECK: return %[[M]] : memref
-func.func @sparse_push_back(%arg0: memref, %arg1: memref, %arg2: f64) -> memref {
-  %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref, memref, f64
-  return %0 : memref
+// CHECK: memref.store %[[C]], %[[M]]{{\[}}%[[A]]]
+// CHECK: return %[[M]], %[[S2]]
+func.func @sparse_push_back(%arg0: index, %arg1: memref, %arg2: f64) -> (memref, index) {
+  %0:2 = sparse_tensor.push_back %arg0, %arg1, %arg2 : index, memref, f64
+  return %0#0, %0#1 : memref, index
 }

// -----

 // CHECK-LABEL: func @sparse_push_back_n(
-// CHECK-SAME: %[[A:.*]]: memref,
+// CHECK-SAME: %[[S1:.*]]: index,
 // CHECK-SAME: %[[B:.*]]: memref,
 // CHECK-SAME: %[[C:.*]]: f64,
-// CHECK-SAME: %[[D:.*]]: index) -> memref {
+// CHECK-SAME: %[[D:.*]]: index) -> (memref, index) {
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
 // CHECK: %[[P1:.*]] = memref.dim %[[B]], %[[C0]]
-// CHECK: %[[S1:.*]] = memref.load %[[A]]{{\[}}%[[C2]]]
 // CHECK: %[[S2:.*]] = arith.addi %[[S1]], %[[D]] : index
 // CHECK: %[[T:.*]] = arith.cmpi ugt, %[[S2]], %[[P1]]
 // CHECK: %[[M:.*]] = scf.if %[[T]] -> (memref) {
@@ -55,29 +52,25 @@
 // CHECK: }
 // CHECK: %[[S:.*]] = memref.subview %[[M]]{{\[}}%[[S1]]] {{\[}}%[[D]]] [1]
 // CHECK: linalg.fill ins(%[[C]] : f64) outs(%[[S]]
-// CHECK: memref.store %[[S2]], %[[A]]{{\[}}%[[C2]]]
-// CHECK: return %[[M]] : memref
-func.func @sparse_push_back_n(%arg0: memref, %arg1: memref, %arg2: f64, %arg3: index) -> memref {
-  %0 = sparse_tensor.push_back %arg0, %arg1, %arg2, %arg3 {idx = 2 : index} : memref, memref, f64, index
-  return %0 : memref
+// CHECK: return %[[M]], %[[S2]] : memref, index
+func.func @sparse_push_back_n(%arg0: index, %arg1: memref, %arg2: f64, %arg3: index) -> (memref, index) {
+  %0:2 = sparse_tensor.push_back %arg0, %arg1, %arg2, %arg3 : index, memref, f64, index
+  return %0#0, %0#1 : memref, index
 }

// -----

 // CHECK-LABEL: func @sparse_push_back_inbound(
-// CHECK-SAME: %[[A:.*]]: memref,
+// CHECK-SAME: %[[S1:.*]]: index,
 // CHECK-SAME: %[[B:.*]]: memref,
-// CHECK-SAME: %[[C:.*]]: f64) -> memref {
+// CHECK-SAME: %[[C:.*]]: f64) -> (memref, index) {
 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[S1:.*]] = memref.load %[[A]]{{\[}}%[[C2]]]
 // CHECK: %[[S2:.*]] = arith.addi %[[S1]], %[[C1]]
 // CHECK: memref.store %[[C]], %[[B]]{{\[}}%[[S1]]]
-// CHECK: memref.store %[[S2]], %[[A]]{{\[}}%[[C2]]]
-// CHECK: return %[[B]] : memref
-func.func @sparse_push_back_inbound(%arg0: memref, %arg1: memref, %arg2: f64) -> memref {
-  %0 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 {idx = 2 : index} : memref, memref, f64
-  return %0 : memref
+// CHECK: return %[[B]], %[[S2]] : memref, index
+func.func @sparse_push_back_inbound(%arg0: index, %arg1: memref, %arg2: f64) -> (memref, index) {
+  %0:2 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 : index, memref, f64
+  return %0#0, %0#1 : memref, index
 }

// -----
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -47,31 +47,28 @@
 }>

 // CHECK-LABEL: func @sparse_nop(
-// CHECK-SAME: %[[A0:.*]]: memref<3xindex>,
 // CHECK-SAME: %[[A1:.*]]: memref,
 // CHECK-SAME: %[[A2:.*]]: memref,
 // CHECK-SAME: %[[A3:.*]]: memref,
-// CHECK-SAME: %[[A4:.*]]: !llvm.struct<(array<1 x i64>)>)
-// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]] :
-// CHECK-SAME: memref<3xindex>, memref, memref, memref, !llvm.struct<(array<1 x i64>)>
+// CHECK-SAME: %[[A4:.*]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>)
+// CHECK: return %[[A1]], %[[A2]], %[[A3]], %[[A4]] :
+// CHECK-SAME: memref, memref, memref, !llvm.struct<(array<1 x i64>, array<3 x i64>)>
 func.func @sparse_nop(%arg0: tensor) -> tensor {
   return %arg0 : tensor
 }

 // CHECK-LABEL: func @sparse_nop_multi_ret(
-// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref,
-// CHECK-SAME: %[[A2:.*2]]: memref,
-// CHECK-SAME: %[[A3:.*3]]: memref,
-// CHECK-SAME: %[[A4:.*4]]: !llvm.struct<(array<1 x i64>)>,
-// CHECK-SAME: %[[A5:.*5]]: memref<3xindex>,
-// CHECK-SAME: %[[A6:.*6]]: memref,
-// CHECK-SAME: %[[A7:.*7]]: memref,
-// CHECK-SAME: %[[A8:.*8]]: memref,
-// CHECK-SAME: %[[A9:.*9]]: !llvm.struct<(array<1 x i64>)>)
-// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[A9]] :
-// CHECK-SAME: memref<3xindex>, memref, memref, memref, !llvm.struct<(array<1 x i64>)>,
-// CHECK-SAME: memref<3xindex>, memref, memref, memref, !llvm.struct<(array<1 x i64>)>
+// CHECK-SAME: %[[A1:.*0]]: memref,
+// CHECK-SAME: %[[A2:.*1]]: memref,
+// CHECK-SAME: %[[A3:.*2]]: memref,
+// CHECK-SAME: %[[A4:.*3]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>,
+// CHECK-SAME: %[[A6:.*4]]: memref,
+// CHECK-SAME: %[[A7:.*5]]: memref,
+// CHECK-SAME: %[[A8:.*6]]: memref,
+// CHECK-SAME: %[[A9:.*7]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>)
+// CHECK: return %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A6]], %[[A7]], %[[A8]], %[[A9]] :
+// CHECK-SAME: memref, memref, memref, !llvm.struct<(array<1 x i64>, array<3 x i64>)>,
+// CHECK-SAME: memref, memref, memref, !llvm.struct<(array<1 x i64>, array<3 x i64>)>
 func.func @sparse_nop_multi_ret(%arg0: tensor,
                                 %arg1: tensor) ->
                                 (tensor, tensor) {
@@ -79,20 +76,18 @@
 }

 // CHECK-LABEL: func @sparse_nop_call(
-// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref,
-// CHECK-SAME: %[[A2:.*2]]: memref,
-// CHECK-SAME: %[[A3:.*3]]: memref,
-// CHECK-SAME: %[[A4:.*4]]: !llvm.struct<(array<1 x i64>)>,
-// CHECK-SAME: %[[A5:.*5]]: memref<3xindex>,
-// CHECK-SAME: %[[A6:.*6]]: memref,
-// CHECK-SAME: %[[A7:.*7]]: memref,
-// CHECK-SAME: %[[A8:.*8]]: memref,
-// CHECK-SAME: %[[A9:.*9]]: !llvm.struct<(array<1 x i64>)>)
-// CHECK: %[[T:.*]]:10 = call @sparse_nop_multi_ret(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[A9]])
-// CHECK: return %[[T]]#0, %[[T]]#1, %[[T]]#2, %[[T]]#3, %[[T]]#4, %[[T]]#5, %[[T]]#6, %[[T]]#7, %[[T]]#8, %[[T]]#9 :
-// CHECK-SAME: memref<3xindex>, memref, memref, memref, !llvm.struct<(array<1 x i64>)>,
-// CHECK-SAME: memref<3xindex>, memref, memref, memref, !llvm.struct<(array<1 x i64>)>
+// CHECK-SAME: %[[A1:.*0]]: memref,
+// CHECK-SAME: %[[A2:.*1]]: memref,
+// CHECK-SAME: %[[A3:.*2]]: memref,
+// CHECK-SAME: %[[A4:.*3]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>,
+// CHECK-SAME: %[[A6:.*4]]: memref,
+// CHECK-SAME: %[[A7:.*5]]: memref,
+// CHECK-SAME: %[[A8:.*6]]: memref,
+// CHECK-SAME: %[[A9:.*7]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>)
+// CHECK: %[[T:.*]]:8 = call @sparse_nop_multi_ret(%[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A6]], %[[A7]], %[[A8]], %[[A9]])
+// CHECK: return %[[T]]#0, %[[T]]#1, %[[T]]#2, %[[T]]#3, %[[T]]#4, %[[T]]#5, %[[T]]#6, %[[T]]#7 :
+// CHECK-SAME: memref, memref, memref, !llvm.struct<(array<1 x i64>, array<3 x i64>)>,
+// CHECK-SAME: memref, memref, memref, !llvm.struct<(array<1 x i64>, array<3 x i64>)>
 func.func @sparse_nop_call(%arg0: tensor,
                            %arg1: tensor) ->
                            (tensor, tensor) {
@@ -103,67 +98,61 @@
 }

 // CHECK-LABEL: func @sparse_nop_cast(
-// CHECK-SAME: %[[A0:.*]]: memref<3xindex>,
 // CHECK-SAME: %[[A1:.*]]: memref,
 // CHECK-SAME: %[[A2:.*]]: memref,
 // CHECK-SAME: %[[A3:.*]]: memref,
-// CHECK-SAME: %[[A4:.*]]: !llvm.struct<(array<1 x i64>)>)
-// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]] :
+// CHECK-SAME: %[[A4:.*]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>)
+// CHECK: return %[[A1]], %[[A2]], %[[A3]], %[[A4]] :
 func.func @sparse_nop_cast(%arg0: tensor<64xf32, #SparseVector>) -> tensor {
   %0 = tensor.cast %arg0 : tensor<64xf32, #SparseVector> to tensor
   return %0 : tensor
 }

 // CHECK-LABEL: func @sparse_nop_cast_3d(
-// CHECK-SAME: %[[A0:.*]]: memref<1xindex>,
 // CHECK-SAME: %[[A1:.*]]: memref,
-// CHECK-SAME: %[[A2:.*]]: !llvm.struct<(array<3 x i64>)>)
-// CHECK: return %[[A0]], %[[A1]], %[[A2]] :
-// CHECK-SAME: memref<1xindex>, memref, !llvm.struct<(array<3 x i64>)>
+// CHECK-SAME: %[[A2:.*]]: !llvm.struct<(array<3 x i64>, array<1 x i64>)>)
+// CHECK: return %[[A1]], %[[A2]] :
+// CHECK-SAME: memref, !llvm.struct<(array<3 x i64>, array<1 x i64>)>
 func.func @sparse_nop_cast_3d(%arg0: tensor<10x20x30xf32, #Dense3D>) -> tensor {
   %0 = tensor.cast %arg0 : tensor<10x20x30xf32, #Dense3D> to tensor
   return %0 : tensor
 }

 // CHECK-LABEL: func @sparse_dense_2d(
-// CHECK-SAME: %[[A0:.*]]: memref<1xindex>,
 // CHECK-SAME: %[[A1:.*]]: memref,
-// CHECK-SAME: %[[A2:.*]]: !llvm.struct<(array<2 x i64>)>) {
+// CHECK-SAME: %[[A2:.*]]: !llvm.struct<(array<2 x i64>, array<1 x i64>)>) {
 // CHECK: return
 func.func @sparse_dense_2d(%arg0: tensor) {
   return
 }

 // CHECK-LABEL: func @sparse_row(
-// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref,
-// CHECK-SAME: %[[A2:.*2]]: memref,
-// CHECK-SAME: %[[A3:.*3]]: memref,
-// CHECK-SAME: %[[A4:.*4]]: !llvm.struct<(array<2 x i64>)>) {
+// CHECK-SAME: %[[A1:.*]]: memref,
+// CHECK-SAME: %[[A2:.*]]: memref,
+// CHECK-SAME: %[[A3:.*]]: memref,
+// CHECK-SAME: %[[A4:.*]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>) {
 // CHECK: return
 func.func @sparse_row(%arg0: tensor) {
   return
 }

 // CHECK-LABEL: func @sparse_csr(
-// CHECK-SAME: %[[A0:.*]]: memref<3xindex>,
 // CHECK-SAME: %[[A1:.*]]: memref,
 // CHECK-SAME: %[[A2:.*]]: memref,
 // CHECK-SAME: %[[A3:.*]]: memref,
-// CHECK-SAME: %[[A4:.*]]: !llvm.struct<(array<2 x i64>)>) {
+// CHECK-SAME: %[[A4:.*]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>) {
 // CHECK: return
 func.func @sparse_csr(%arg0: tensor) {
   return
 }

 // CHECK-LABEL: func @sparse_dcsr(
-// CHECK-SAME: %[[A0:.*0]]: memref<5xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref,
-// CHECK-SAME: %[[A2:.*2]]: memref,
-// CHECK-SAME: %[[A3:.*3]]: memref,
-// CHECK-SAME: %[[A4:.*4]]: memref,
-// CHECK-SAME: %[[A5:.*5]]: memref,
-// CHECK-SAME: %[[A6:.*6]]: !llvm.struct<(array<2 x i64>)>)
+// CHECK-SAME: %[[A1:.*0]]: memref,
+// CHECK-SAME: %[[A2:.*1]]: memref,
+// CHECK-SAME: %[[A3:.*2]]: memref,
+// CHECK-SAME: %[[A4:.*3]]: memref,
+// CHECK-SAME: %[[A5:.*4]]: memref,
+// CHECK-SAME: %[[A6:.*5]]: !llvm.struct<(array<2 x i64>, array<5 x i64>)>)
 // CHECK: return
 func.func @sparse_dcsr(%arg0: tensor) {
   return
@@ -174,9 +163,8 @@
 // fold using the original static dimension sizes.
 //
 // CHECK-LABEL: func @sparse_dense_3d(
-// CHECK-SAME: %[[A0:.*]]: memref<1xindex>,
 // CHECK-SAME: %[[A1:.*]]: memref,
-// CHECK-SAME: %[[A2:.*]]: !llvm.struct<(array<3 x i64>)>)
+// CHECK-SAME: %[[A2:.*]]: !llvm.struct<(array<3 x i64>, array<1 x i64>)>)
 // CHECK: %[[C:.*]] = arith.constant 20 : index
 // CHECK: return %[[C]] : index
 func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index {
@@ -191,9 +179,8 @@
 // since the latter honors the dimOrdering.
 //
 // CHECK-LABEL: func @sparse_dense_3d_dyn(
-// CHECK-SAME: %[[A0:.*]]: memref<1xindex>,
 // CHECK-SAME: %[[A1:.*]]: memref,
-// CHECK-SAME: %[[A2:.*]]: !llvm.struct<(array<3 x i64>)>)
+// CHECK-SAME: %[[A2:.*]]: !llvm.struct<(array<3 x i64>, array<1 x i64>)>)
 // CHECK: %[[A3:.*]] = llvm.extractvalue %[[A2]][0, 2]
 // CHECK: %[[A4:.*]] = arith.index_cast %[[A3]] : i64 to index
 // CHECK: return %[[A4]] : index
@@ -204,13 +191,12 @@
 }

 // CHECK-LABEL: func @sparse_pointers_dcsr(
-// CHECK-SAME: %[[A0:.*0]]: memref<5xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref,
-// CHECK-SAME: %[[A2:.*2]]: memref,
-// CHECK-SAME: %[[A3:.*3]]: memref,
-// CHECK-SAME: %[[A4:.*4]]: memref,
-// CHECK-SAME: %[[A5:.*5]]: memref,
-// CHECK-SAME: %[[A6:.*6]]: !llvm.struct<(array<2 x i64>)>
+// CHECK-SAME: %[[A1:.*0]]: memref,
+// CHECK-SAME: %[[A2:.*1]]: memref,
+// CHECK-SAME: %[[A3:.*2]]: memref,
+// CHECK-SAME: %[[A4:.*3]]: memref,
+// CHECK-SAME: %[[A5:.*4]]: memref,
+// CHECK-SAME: %[[A6:.*5]]: !llvm.struct<(array<2 x i64>, array<5 x i64>)>
 // CHECK: return %[[A3]] : memref
 func.func @sparse_pointers_dcsr(%arg0: tensor) -> memref {
   %0 = sparse_tensor.pointers %arg0 { dimension = 1 : index } : tensor to memref
@@ -218,13 +204,12 @@
 }

 // CHECK-LABEL: func @sparse_indices_dcsr(
-// CHECK-SAME: %[[A0:.*0]]: memref<5xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref,
-// CHECK-SAME: %[[A2:.*2]]: memref,
-// CHECK-SAME: %[[A3:.*3]]: memref,
-// CHECK-SAME: %[[A4:.*4]]: memref,
-// CHECK-SAME: %[[A5:.*5]]: memref,
-// CHECK-SAME: %[[A6:.*6]]: !llvm.struct<(array<2 x i64>)>
+// CHECK-SAME: %[[A1:.*0]]: memref,
+// CHECK-SAME: %[[A2:.*1]]: memref,
+// CHECK-SAME: %[[A3:.*2]]: memref,
+// CHECK-SAME: %[[A4:.*3]]: memref,
+// CHECK-SAME: %[[A5:.*4]]: memref,
+// CHECK-SAME: %[[A6:.*5]]: !llvm.struct<(array<2 x i64>, array<5 x i64>)>
 // CHECK: return %[[A4]] : memref
 func.func @sparse_indices_dcsr(%arg0: tensor) -> memref {
   %0 = sparse_tensor.indices %arg0 { dimension = 1 : index } : tensor to memref
@@ -232,13 +217,12 @@
 }

 // CHECK-LABEL: func @sparse_values_dcsr(
-// CHECK-SAME: %[[A0:.*0]]: memref<5xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref,
-// CHECK-SAME: %[[A2:.*2]]: memref,
-// CHECK-SAME: %[[A3:.*3]]: memref,
-// CHECK-SAME: %[[A4:.*4]]: memref,
-// CHECK-SAME: %[[A5:.*5]]: memref,
-// CHECK-SAME: %[[A6:.*6]]: !llvm.struct<(array<2 x i64>)>)
+// CHECK-SAME: %[[A1:.*0]]: memref,
+// CHECK-SAME: %[[A2:.*1]]: memref,
+// CHECK-SAME: %[[A3:.*2]]: memref,
+// CHECK-SAME: %[[A4:.*3]]: memref,
+// CHECK-SAME: %[[A5:.*4]]: memref,
+// CHECK-SAME: %[[A6:.*5]]: !llvm.struct<(array<2 x i64>, array<5 x i64>)>)
 // CHECK: return %[[A5]] : memref
 func.func @sparse_values_dcsr(%arg0: tensor) -> memref {
   %0 = sparse_tensor.values %arg0 : tensor to memref
@@ -246,13 +230,12 @@
 }

 // CHECK-LABEL: func @sparse_noe(
-// CHECK-SAME: %[[A0:.*]]: memref<3xindex>,
 // CHECK-SAME: %[[A1:.*]]: memref,
 // CHECK-SAME: %[[A2:.*]]: memref,
 // CHECK-SAME: %[[A3:.*]]: memref,
-// CHECK-SAME: %[[A4:.*]]: !llvm.struct<(array<1 x i64>)>)
-// CHECK: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[NOE:.*]] = memref.load %[[A0]][%[[C2]]] : memref<3xindex>
+// CHECK-SAME: %[[A4:.*]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>)
+// CHECK: %[[A5:.*]] = llvm.extractvalue %[[A4]][1, 2] : !llvm.struct<(array<1 x i64>, array<3 x i64>)>
+// CHECK: %[[NOE:.*]] = arith.index_cast %[[A5]] : i64 to index
 // CHECK: return %[[NOE]] : index
 func.func @sparse_noe(%arg0: tensor<128xf64, #SparseVector>) -> index {
   %0 = sparse_tensor.number_of_entries %arg0 :
        tensor<128xf64, #SparseVector>
@@ -260,12 +243,10 @@
 }

 // CHECK-LABEL: func @sparse_dealloc_csr(
-// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref,
-// CHECK-SAME: %[[A2:.*2]]: memref,
-// CHECK-SAME: %[[A3:.*3]]: memref,
-// CHECK-SAME: %[[A4:.*4]]: !llvm.struct<(array<2 x i64>)>)
-// CHECK: memref.dealloc %[[A0]] : memref<3xindex>
+// CHECK-SAME: %[[A1:.*]]: memref,
+// CHECK-SAME: %[[A2:.*]]: memref,
+// CHECK-SAME: %[[A3:.*]]: memref,
+// CHECK-SAME: %[[A4:.*]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>)
 // CHECK: memref.dealloc %[[A1]] : memref
 // CHECK: memref.dealloc %[[A2]] : memref
 // CHECK: memref.dealloc %[[A3]] : memref
@@ -275,48 +256,58 @@
   return
 }

-// CHECK-LABEL: func @sparse_alloc_csc(
-// CHECK-SAME: %[[A0:.*]]: index)
-// CHECK-DAG: %[[A1:.*]] = arith.constant 10 : i64
-// CHECK-DAG: %[[A2:.*]] = arith.constant 0 : index
-// CHECK: %[[A3:.*]] = memref.alloc() : memref<3xindex>
+// CHECK-LABEL: func.func @sparse_alloc_csc(
+// CHECK-SAME: %[[A0:.*]]: index) ->
+// CHECK-SAME: (memref, memref, memref, !llvm.struct<(array<2 x i64>, array<3 x i64>)>)
+// CHECK-DAG: %[[A1:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[A2:.*]] = arith.constant 10 : i64
+// CHECK-DAG: %[[A3:.*]] = arith.constant 0 : i64
 // CHECK: %[[A4:.*]] = memref.alloc() : memref<16xindex>
-// CHECK: %[[A5:.*]] = memref.cast %[[A4]] : memref<16xindex> to memref
+// CHECK: %[[A5:.*]] = memref.cast %[[A4]]
 // CHECK: %[[A6:.*]] = memref.alloc() : memref<16xindex>
-// CHECK: %[[A7:.*]] = memref.cast %[[A6]] : memref<16xindex> to memref
+// CHECK: %[[A7:.*]] = memref.cast %[[A6]]
 // CHECK: %[[A8:.*]] = memref.alloc() : memref<16xf64>
-// CHECK: %[[A9:.*]] = memref.cast %[[A8]] : memref<16xf64> to memref
-// CHECK: %[[A10:.*]] = llvm.mlir.undef : !llvm.struct<(array<2 x i64>)>
-// CHECK: linalg.fill ins(%[[A2]] : index) outs(%[[A3]] : memref<3xindex>)
-// CHECK: %[[A11:.*]] = arith.index_cast %[[A0]] : index to i64
-// CHECK: %[[A12:.*]] = llvm.insertvalue %[[A11]], %[[A10]][0, 0]
-// CHECK: %[[A13:.*]] = llvm.insertvalue %[[A1]], %[[A12]][0, 1]
-// CHECK: %[[A14:.*]] = sparse_tensor.push_back %[[A3]], %[[A5]], %[[A2]] {idx = 0 : index} : memref<3xindex>, memref, index
-// CHECK: %[[A15:.*]] = sparse_tensor.push_back %[[A3]], %[[A14]], %[[A2]], %[[A0]] {idx = 0 : index} : memref<3xindex>, memref, index, index
-// CHECK: return %[[A3]], %[[A15]], %[[A7]], %[[A9]], %[[A13]] : memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)>
+// CHECK: %[[A9:.*]] = memref.cast %[[A8]]
+// CHECK: %[[A10:.*]] = llvm.mlir.undef
+// CHECK: %[[A11:.*]] = llvm.insertvalue %[[A3]], %[[A10]][1, 0]
+// CHECK: %[[A12:.*]] = llvm.insertvalue %[[A3]], %[[A11]][1, 1]
+// CHECK: %[[A13:.*]] = llvm.insertvalue %[[A3]], %[[A12]][1, 2]
+// CHECK: %[[A14:.*]] = arith.index_cast %[[A0]] : index to i64
+// CHECK: %[[A15:.*]] = llvm.insertvalue %[[A14]], %[[A13]][0, 0]
+// CHECK: %[[A16:.*]] = llvm.insertvalue %[[A2]], %[[A15]][0, 1]
+// CHECK: %[[A17:.*]], %[[A18:.*]] = sparse_tensor.push_back %[[A1]], %[[A5]], %[[A1]] : index, memref, index
+// CHECK: %[[A19:.*]] = arith.index_cast %[[A18]] : index to i64
+// CHECK: %[[A20:.*]] = llvm.insertvalue %[[A19]], %[[A16]][1, 0]
+// CHECK: %[[A21:.*]], %[[A22:.*]] = sparse_tensor.push_back %[[A18]], %[[A17]], %[[A1]], %[[A0]] : index, memref, index, index
+// CHECK: %[[A23:.*]] = arith.index_cast %[[A22]] : index to i64
+// CHECK: %[[A24:.*]] = llvm.insertvalue %[[A23]], %[[A20]][1, 0]
+// CHECK: return %[[A21]], %[[A7]], %[[A9]], %[[A24]] : memref, memref, memref, !llvm.struct<(array<2 x i64>, array<3 x i64>)>
memref, memref, memref, !llvm.struct<(array<2 x i64>, array<3 x i64>)> func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> { %0 = bufferization.alloc_tensor(%arg0) : tensor<10x?xf64, #CSC> %1 = sparse_tensor.load %0 : tensor<10x?xf64, #CSC> return %1 : tensor<10x?xf64, #CSC> } -// CHECK-LABEL: func @sparse_alloc_3d() -// CHECK-DAG: %[[A0:.*]] = arith.constant 6000 : index -// CHECK-DAG: %[[A1:.*]] = arith.constant 20 : i64 -// CHECK-DAG: %[[A2:.*]] = arith.constant 10 : i64 -// CHECK-DAG: %[[A3:.*]] = arith.constant 30 : i64 -// CHECK-DAG: %[[A4:.*]] = arith.constant 0.000000e+00 : f64 -// CHECK-DAG: %[[A5:.*]] = arith.constant 0 : index -// CHECK: %[[A6:.*]] = memref.alloc() : memref<1xindex> +// CHECK-LABEL: func.func @sparse_alloc_3d() +// CHECK-SAME: -> (memref, !llvm.struct<(array<3 x i64>, array<1 x i64>)>) { +// CHECK-DAG: %[[A0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[A1:.*]] = arith.constant 6000 : index +// CHECK-DAG: %[[A2:.*]] = arith.constant 20 : i64 +// CHECK-DAG: %[[A3:.*]] = arith.constant 10 : i64 +// CHECK-DAG: %[[A4:.*]] = arith.constant 30 : i64 +// CHECK-DAG: %[[A5:.*]] = arith.constant 0.000000e+00 : f64 +// CHECK-DAG: %[[A6:.*]] = arith.constant 0 : i64 // CHECK: %[[A7:.*]] = memref.alloc() : memref<16xf64> // CHECK: %[[A8:.*]] = memref.cast %[[A7]] : memref<16xf64> to memref -// CHECK: %[[A9:.*]] = llvm.mlir.undef : !llvm.struct<(array<3 x i64>)> -// CHECK: linalg.fill ins(%[[A5]] : index) outs(%[[A6]] : memref<1xindex>) -// CHECK: %[[A10:.*]] = llvm.insertvalue %[[A3]], %[[A9]][0, 0] -// CHECK: %[[A11:.*]] = llvm.insertvalue %[[A2]], %[[A10]][0, 1] -// CHECK: %[[A12:.*]] = llvm.insertvalue %[[A1]], %[[A11]][0, 2] -// CHECK: %[[A13:.*]] = sparse_tensor.push_back %[[A6]], %[[A8]], %[[A4]], %[[A0]] {idx = 0 : index} : memref<1xindex>, memref, f64, index -// CHECK: return %[[A6]], %[[A13]], %[[A12]] : memref<1xindex>, memref, !llvm.struct<(array<3 x i64>)> +// CHECK: %[[A9:.*]] = llvm.mlir.undef +// CHECK: %[[A10:.*]] = llvm.insertvalue %[[A6]], %[[A9]][1, 0] +// CHECK: %[[A11:.*]] = llvm.insertvalue %[[A4]], %[[A10]][0, 0] +// CHECK: %[[A12:.*]] = llvm.insertvalue %[[A3]], %[[A11]][0, 1] +// CHECK: %[[A13:.*]] = llvm.insertvalue %[[A2]], %[[A12]][0, 2] +// CHECK: %[[A14:.*]], %[[A15:.*]] = sparse_tensor.push_back %[[A0]], %[[A8]], %[[A5]], %[[A1]] : index, memref, f64, index +// CHECK: %[[A16:.*]] = arith.index_cast %[[A15]] : index to i64 +// CHECK: %[[A17:.*]] = llvm.insertvalue %[[A16]], %[[A13]][1, 0] +// CHECK: return %[[A14]], %[[A17]] : memref, !llvm.struct<(array<3 x i64>, array<1 x i64>)> func.func @sparse_alloc_3d() -> tensor<10x20x30xf64, #Dense3D> { %0 = bufferization.alloc_tensor() : tensor<10x20x30xf64, #Dense3D> %1 = sparse_tensor.load %0 : tensor<10x20x30xf64, #Dense3D> @@ -370,43 +361,39 @@ } // CHECK-LABEL: func.func private @_insert_C_100_f64_0_0( -// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: !llvm.struct<(array<1 x i64>)>, -// CHECK-SAME: %[[A5:.*5]]: index, -// CHECK-SAME: %[[A6:.*6]]: f64) -// CHECK: %[[A22:.*]] = sparse_tensor.push_back %[[A0]], %[[A3]], %[[A6]] {idx = 2 : index} : memref<3xindex>, memref, f64 -// CHECK: return %[[A0]], %[[A1]], %{{.*}}, %[[A22]], %[[A4]] : memref<3xindex>, memref, memref, memref, !llvm.struct<(array<1 x i64>)> +// CHECK-SAME: %[[A1:.*0]]: memref, +// CHECK-SAME: %[[A2:.*1]]: memref, +// CHECK-SAME: %[[A3:.*2]]: memref, +// CHECK-SAME: %[[A4:.*3]]: 
!llvm.struct<(array<1 x i64>, array<3 x i64>)>, +// CHECK-SAME: %[[A5:.*4]]: index, +// CHECK-SAME: %[[A6:.*5]]: f64) // -// CHECK-LABEL: func @sparse_compression_1d( -// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>, +// CHECK-LABEL: func.func @sparse_compression_1d( +// CHECK-SAME: %[[A0:.*0]]: memref, // CHECK-SAME: %[[A1:.*1]]: memref, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: !llvm.struct<(array<1 x i64>)>, -// CHECK-SAME: %[[A5:.*5]]: memref, -// CHECK-SAME: %[[A6:.*6]]: memref, -// CHECK-SAME: %[[A7:.*7]]: memref, -// CHECK-SAME: %[[A8:.*8]]: index) -// CHECK-DAG: %[[A9:.*]] = arith.constant false -// CHECK-DAG: %[[A10:.*]] = arith.constant 0.000000e+00 : f64 -// CHECK-DAG: %[[A11:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[A12:.*]] = arith.constant 0 : index -// CHECK: sparse_tensor.sort %[[A8]], %[[A7]] : memref -// CHECK: %[[A13:.*]]:5 = scf.for %[[A14:.*]] = %[[A12]] to %[[A8]] step %[[A11]] iter_args(%[[A15:.*]] = %[[A0]], %[[A16:.*]] = %[[A1]], %[[A17:.*]] = %[[A2]], %[[A18:.*]] = %[[A3]], %[[A19:.*]] = %[[A4]]) -> (memref<3xindex>, memref, memref, memref, !llvm.struct<(array<1 x i64>)>) { -// CHECK: %[[A20:.*]] = memref.load %[[A7]]{{\[}}%[[A14]]] : memref -// CHECK: %[[A21:.*]] = memref.load %[[A5]]{{\[}}%[[A20]]] : memref -// CHECK: %[[A22:.*]]:5 = func.call @_insert_C_100_f64_0_0(%[[A15]], %[[A16]], %[[A17]], %[[A18]], %[[A19]], %[[A20]], %[[A21]]) : (memref<3xindex>, memref, memref, memref, !llvm.struct<(array<1 x i64>)>, index, f64) -> (memref<3xindex>, memref, memref, memref, !llvm.struct<(array<1 x i64>)>) -// CHECK: memref.store %[[A10]], %[[A5]]{{\[}}%[[A20]]] : memref -// CHECK: memref.store %[[A9]], %[[A6]]{{\[}}%[[A20]]] : memref -// CHECK: scf.yield %[[A22]]#0, %[[A22]]#1, %[[A22]]#2, %[[A22]]#3, %[[A22]]#4 : memref<3xindex>, memref, memref, memref, !llvm.struct<(array<1 x i64>)> +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>, +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: memref, +// CHECK-SAME: %[[A6:.*6]]: memref, +// CHECK-SAME: %[[A7:.*7]]: index) -> (memref, memref, memref, !llvm.struct<(array<1 x i64>, array<3 x i64>)>) { +// CHECK-DAG: %[[A8:.*]] = arith.constant false +// CHECK-DAG: %[[A9:.*]] = arith.constant 0.000000e+00 : f64 +// CHECK-DAG: %[[A10:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[A11:.*]] = arith.constant 0 : index +// CHECK: sparse_tensor.sort %[[A7]], %[[A6]] : memref +// CHECK: %[[A12:.*]]:4 = scf.for %[[A13:.*]] = %[[A11]] to %[[A7]] step %[[A10]] iter_args(%[[A14:.*]] = %[[A0]], %[[A15:.*]] = %[[A1]], %[[A16:.*]] = %[[A2]], %[[A17:.*]] = %[[A3]]) +// CHECK: %[[A18:.*]] = memref.load %[[A6]]{{\[}}%[[A13]]] : memref +// CHECK: %[[A19:.*]] = memref.load %[[A4]]{{\[}}%[[A18]]] : memref +// CHECK: %[[A20:.*]]:4 = func.call @_insert_C_100_f64_0_0(%[[A14]], %[[A15]], %[[A16]], %[[A17]], %[[A18]], %[[A19]]) +// CHECK: memref.store %[[A9]], %[[A4]]{{\[}}%[[A18]]] : memref +// CHECK: memref.store %[[A8]], %[[A5]]{{\[}}%[[A18]]] : memref +// CHECK: scf.yield %[[A20]]#0, %[[A20]]#1, %[[A20]]#2, %[[A20]]#3 // CHECK: } -// CHECK: memref.dealloc %[[A5]] : memref -// CHECK: memref.dealloc %[[A6]] : memref -// CHECK: memref.dealloc %[[A7]] : memref -// CHECK: return %[[A23:.*]]#0, %[[A23]]#1, %[[A23]]#2, %[[A23]]#3, %[[A23]]#4 : memref<3xindex>, memref, memref, memref, !llvm.struct<(array<1 x i64>)> +// CHECK: memref.dealloc %[[A4]] : memref +// CHECK: memref.dealloc %[[A5]] : memref +// CHECK: 
memref.dealloc %[[A6]] : memref +// CHECK: return %[[A21:.*]]#0, %[[A21]]#1, %[[A21]]#2, %[[A21]]#3 : memref, memref, memref, !llvm.struct<(array<1 x i64>, array<3 x i64>)> func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>, %values: memref, %filled: memref, @@ -419,46 +406,54 @@ } // CHECK-LABEL: func.func private @_insert_D_C_8_8_f64_64_32( -// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: !llvm.struct<(array<2 x i64>)>, -// CHECK-SAME: %[[A5:.*5]]: index, -// CHECK-SAME: %[[A6:.*6]]: index, -// CHECK-SAME: %[[A7:.*7]]: f64) -// CHECK: %[[PV:.*]] = sparse_tensor.push_back %[[A0]], %[[A3]], %[[A7]] {idx = 2 : index} : memref<3xindex>, memref, f64 -// CHECK: return %[[A0]], %[[A1]], %{{.*}}, %[[PV]], %[[A4]] -/// -// CHECK-LABEL: func @sparse_compression( -// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: !llvm.struct<(array<2 x i64>)>, -// CHECK-SAME: %[[A5:.*5]]: memref, -// CHECK-SAME: %[[A6:.*6]]: memref, -// CHECK-SAME: %[[A7:.*7]]: memref, -// CHECK-SAME: %[[A8:.*8]]: index, -// CHECK-SAME: %[[A9:.*9]]: index) -// CHECK-DAG: %[[A10:.*]] = arith.constant 0 : i32 -// CHECK-DAG: %[[A11:.*]] = arith.constant false -// CHECK-DAG: %[[A12:.*]] = arith.constant 0.000000e+00 : f64 -// CHECK-DAG: %[[A13:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[A14:.*]] = arith.constant 0 : index -// CHECK: sparse_tensor.sort %[[A8]], %[[A7]] : memref -// CHECK: %[[R:.*]]:5 = scf.for %[[A16:.*]] = %[[A14]] to %[[A8]] step %[[A13]] iter_args(%[[A17:.*]] = %[[A0]], %[[A18:.*]] = %[[A1]], %[[A19:.*]] = %[[A2]], %[[A20:.*]] = %[[A3]], %[[A21:.*]] = %[[A4]]) -> (memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)>) { -// CHECK: %[[A22:.*]] = memref.load %[[A7]]{{\[}}%[[A16]]] : memref -// CHECK: %[[A23:.*]] = memref.load %[[A5]]{{\[}}%[[A22]]] : memref -// CHECK: %[[A24:.*]]:5 = func.call @_insert_D_C_8_8_f64_64_32(%[[A17]], %[[A18]], %[[A19]], %[[A20]], %[[A21]], %[[A9]], %[[A22]], %[[A23]]) -// CHECK: memref.store %[[A12]], %[[A5]]{{\[}}%[[A22]]] : memref -// CHECK: memref.store %[[A11]], %[[A6]]{{\[}}%[[A22]]] : memref -// CHECK: scf.yield %[[A24]]#0, %[[A24]]#1, %[[A24]]#2, %[[A24]]#3, %[[A24]]#4 : memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)> +// CHECK-SAME: %[[A1:.*0]]: memref, +// CHECK-SAME: %[[A2:.*1]]: memref, +// CHECK-SAME: %[[A3:.*2]]: memref, +// CHECK-SAME: %[[A4:.*3]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>, +// CHECK-SAME: %[[A5:.*4]]: index, +// CHECK-SAME: %[[A6:.*5]]: index, +// CHECK-SAME: %[[A7:.*6]]: f64) +// +// CHECK-LABEL: func.func @sparse_compression( +// CHECK-SAME: %[[A0:.*0]]: memref, +// CHECK-SAME: %[[A1:.*1]]: memref, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>, +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: memref, +// CHECK-SAME: %[[A6:.*6]]: memref, +// CHECK-SAME: %[[A7:.*7]]: index, +// CHECK-SAME: %[[A8:.*8]]: index) -> (memref, memref, memref, !llvm.struct<(array<2 x i64>, array<3 x i64>)>) { +// CHECK-DAG: %[[A9:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[A10:.*]] = arith.constant false +// CHECK-DAG: %[[A11:.*]] = arith.constant 0.000000e+00 : f64 +// CHECK-DAG: %[[A12:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[A13:.*]] = arith.constant 0 
: index +// CHECK: sparse_tensor.sort %[[A7]], %[[A6]] : memref +// CHECK: %[[A14:.*]]:4 = scf.for %[[A15:.*]] = %[[A13]] to %[[A7]] step %[[A12]] iter_args(%[[A16:.*]] = %[[A0]], %[[A17:.*]] = %[[A1]], %[[A18:.*]] = %[[A2]], %[[A19:.*]] = %[[A3]]) +// CHECK: %[[A20:.*]] = memref.load %[[A6]]{{\[}}%[[A15]]] : memref +// CHECK: %[[A21:.*]] = memref.load %[[A4]]{{\[}}%[[A20]]] : memref +// CHECK: %[[A22:.*]]:4 = func.call @_insert_D_C_8_8_f64_64_32(%[[A16]], %[[A17]], %[[A18]], %[[A19]], %[[A8]], %[[A20]], %[[A21]]) +// CHECK: memref.store %[[A11]], %[[A4]]{{\[}}%[[A20]]] : memref +// CHECK: memref.store %[[A10]], %[[A5]]{{\[}}%[[A20]]] : memref +// CHECK: scf.yield %[[A22]]#0, %[[A22]]#1, %[[A22]]#2, %[[A22]]#3 +// CHECK: } +// CHECK: memref.dealloc %[[A4]] : memref +// CHECK: memref.dealloc %[[A5]] : memref +// CHECK: memref.dealloc %[[A6]] : memref +// CHECK: %[[A23:.*]] = llvm.extractvalue %[[A24:.*]]#3[1, 0] +// CHECK: %[[A25:.*]] = arith.index_cast %[[A23]] : i64 to index +// CHECK: %[[A26:.*]] = memref.load %[[A24]]#0{{\[}}%[[A13]]] : memref +// CHECK: %[[A27:.*]] = scf.for %[[A28:.*]] = %[[A12]] to %[[A25]] step %[[A12]] iter_args(%[[A29:.*]] = %[[A26]]) -> (i32) { +// CHECK: %[[A30:.*]] = memref.load %[[A24]]#0{{\[}}%[[A28]]] : memref +// CHECK: %[[A31:.*]] = arith.cmpi eq, %[[A30]], %[[A9]] : i32 +// CHECK: %[[A32:.*]] = arith.select %[[A31]], %[[A29]], %[[A30]] : i32 +// CHECK: scf.if %[[A31]] { +// CHECK: memref.store %[[A29]], %[[A24]]#0{{\[}}%[[A28]]] : memref +// CHECK: } +// CHECK: scf.yield %[[A32]] : i32 // CHECK: } -// CHECK: memref.dealloc %[[A5]] : memref -// CHECK: memref.dealloc %[[A6]] : memref -// CHECK: memref.dealloc %[[A7]] : memref -// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4 : memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)> +// CHECK: return %[[A24]]#0, %[[A24]]#1, %[[A24]]#2, %[[A24]]#3 : memref, memref, memref, !llvm.struct<(array<2 x i64>, array<3 x i64>)> func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>, %values: memref, %filled: memref, @@ -472,45 +467,52 @@ } // CHECK-LABEL: func.func private @_insert_D_C_8_8_f64_0_0( -// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: !llvm.struct<(array<2 x i64>)>, -// CHECK-SAME: %[[A5:.*5]]: index, -// CHECK-SAME: %[[A6:.*6]]: index, -// CHECK-SAME: %[[A7:.*7]]: f64) -// CHECK: %[[PV:.*]] = sparse_tensor.push_back %[[A0]], %[[A3]], %[[A7]] {idx = 2 : index} : memref<3xindex>, memref, f64 -// CHECK: return %[[A0]], %[[A1]], %{{.*}}, %[[PV]], %[[A4]] +// CHECK-SAME: %[[A1:.*0]]: memref, +// CHECK-SAME: %[[A2:.*1]]: memref, +// CHECK-SAME: %[[A3:.*2]]: memref, +// CHECK-SAME: %[[A4:.*3]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>, +// CHECK-SAME: %[[A5:.*4]]: index, +// CHECK-SAME: %[[A6:.*5]]: index, +// CHECK-SAME: %[[A7:.*6]]: f64) // -// CHECK-LABEL: func @sparse_compression_unordered( -// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>, +// CHECK-LABEL: func.func @sparse_compression_unordered( +// CHECK-SAME: %[[A0:.*0]]: memref, // CHECK-SAME: %[[A1:.*1]]: memref, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: !llvm.struct<(array<2 x i64>)>, -// CHECK-SAME: %[[A5:.*5]]: memref, -// CHECK-SAME: %[[A6:.*6]]: memref, -// CHECK-SAME: %[[A7:.*7]]: memref, -// CHECK-SAME: %[[A8:.*8]]: index, -// CHECK-SAME: %[[A9:.*9]]: index) -// CHECK-DAG: %[[A10:.*]] = arith.constant false -// 
CHECK-DAG: %[[A11:.*]] = arith.constant 0.000000e+00 : f64 -// CHECK-DAG: %[[A12:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[A13:.*]] = arith.constant 1 : index -// CHECK-NOT: sparse_tensor.sort -// CHECK: %[[R:.*]]:5 = scf.for %[[A15:.*]] = %[[A12]] to %[[A8]] step %[[A13]] iter_args(%[[A16:.*]] = %[[A0]], %[[A17:.*]] = %[[A1]], %[[A18:.*]] = %[[A2]], %[[A19:.*]] = %[[A3]], %[[A20:.*]] = %[[A4]]) -> (memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)>) { -// CHECK: %[[A21:.*]] = memref.load %[[A7]]{{\[}}%[[A15]]] : memref -// CHECK: %[[A22:.*]] = memref.load %[[A5]]{{\[}}%[[A21]]] : memref -// CHECK: %[[A23:.*]]:5 = func.call @_insert_D_C_8_8_f64_0_0(%[[A16]], %[[A17]], %[[A18]], %[[A19]], %[[A20]], %[[A9]], %[[A21]], %[[A22]]) : (memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)>, index, index, f64) -> (memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)>) -// CHECK: memref.store %[[A11]], %[[A5]]{{\[}}%[[A21]]] : memref -// CHECK: memref.store %[[A10]], %[[A6]]{{\[}}%[[A21]]] : memref -// CHECK: scf.yield %[[A23]]#0, %[[A23]]#1, %[[A23]]#2, %[[A23]]#3, %[[A23]]#4 : memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)> +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>, +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: memref, +// CHECK-SAME: %[[A6:.*6]]: memref, +// CHECK-SAME: %[[A7:.*7]]: index, +// CHECK-SAME: %[[A8:.*8]]: index) +// CHECK-DAG: %[[A9:.*]] = arith.constant false +// CHECK-DAG: %[[A10:.*]] = arith.constant 0.000000e+00 : f64 +// CHECK-DAG: %[[A11:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[A12:.*]] = arith.constant 1 : index +// CHECK: %[[A13:.*]]:4 = scf.for %[[A14:.*]] = %[[A11]] to %[[A7]] step %[[A12]] iter_args(%[[A15:.*]] = %[[A0]], %[[A16:.*]] = %[[A1]], %[[A17:.*]] = %[[A2]], %[[A18:.*]] = %[[A3]]) -> (memref, memref, memref, !llvm.struct<(array<2 x i64>, array<3 x i64>)>) { +// CHECK: %[[A19:.*]] = memref.load %[[A6]]{{\[}}%[[A14]]] : memref +// CHECK: %[[A20:.*]] = memref.load %[[A4]]{{\[}}%[[A19]]] : memref +// CHECK: %[[A21:.*]]:4 = func.call @_insert_D_C_8_8_f64_0_0(%[[A15]], %[[A16]], %[[A17]], %[[A18]], %[[A8]], %[[A19]], %[[A20]]) +// CHECK: memref.store %[[A10]], %[[A4]]{{\[}}%[[A19]]] : memref +// CHECK: memref.store %[[A9]], %[[A5]]{{\[}}%[[A19]]] : memref +// CHECK: scf.yield %[[A21]]#0, %[[A21]]#1, %[[A21]]#2, %[[A21]]#3 +// CHECK: } +// CHECK: memref.dealloc %[[A4]] : memref +// CHECK: memref.dealloc %[[A5]] : memref +// CHECK: memref.dealloc %[[A6]] : memref +// CHECK: %[[A22:.*]] = llvm.extractvalue %[[A23:.*]]#3[1, 0] +// CHECK: %[[A24:.*]] = arith.index_cast %[[A22]] : i64 to index +// CHECK: %[[A25:.*]] = memref.load %[[A23]]#0{{\[}}%[[A11]]] : memref +// CHECK: %[[A26:.*]] = scf.for %[[A27:.*]] = %[[A12]] to %[[A24]] step %[[A12]] iter_args(%[[A28:.*]] = %[[A25]]) -> (index) { +// CHECK: %[[A29:.*]] = memref.load %[[A23]]#0{{\[}}%[[A27]]] : memref +// CHECK: %[[A30:.*]] = arith.cmpi eq, %[[A29]], %[[A11]] : index +// CHECK: %[[A31:.*]] = arith.select %[[A30]], %[[A28]], %[[A29]] : index +// CHECK: scf.if %[[A30]] { +// CHECK: memref.store %[[A28]], %[[A23]]#0{{\[}}%[[A27]]] : memref +// CHECK: } +// CHECK: scf.yield %[[A31]] : index // CHECK: } -// CHECK: memref.dealloc %[[A5]] : memref -// CHECK: memref.dealloc %[[A6]] : memref -// CHECK: memref.dealloc %[[A7]] : memref -// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4 : memref<3xindex>, memref, memref, 
@@ -524,26 +526,22 @@
 }
 // CHECK-LABEL: func.func private @_insert_C_128_f64_0_0(
-// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref,
-// CHECK-SAME: %[[A2:.*2]]: memref,
-// CHECK-SAME: %[[A3:.*3]]: memref,
-// CHECK-SAME: %[[A4:.*4]]: !llvm.struct<(array<1 x i64>)>,
-// CHECK-SAME: %[[A5:.*5]]: index,
-// CHECK-SAME: %[[A6:.*6]]: f64)
-// CHECK: %[[P:.*]] = sparse_tensor.push_back %[[A0]], %[[A3]], %[[A6]] {idx = 2 : index}
-// CHECK: return %[[A0]], %[[A1]], %{{.*}}, %[[P]], %[[A4]]
-
+// CHECK-SAME: %[[A1:.*0]]: memref,
+// CHECK-SAME: %[[A2:.*1]]: memref,
+// CHECK-SAME: %[[A3:.*2]]: memref,
+// CHECK-SAME: %[[A4:.*3]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>,
+// CHECK-SAME: %[[A5:.*4]]: index,
+// CHECK-SAME: %[[A6:.*5]]: f64)
+//
 // CHECK-LABEL: func @sparse_insert(
-// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref,
-// CHECK-SAME: %[[A2:.*2]]: memref,
-// CHECK-SAME: %[[A3:.*3]]: memref,
-// CHECK-SAME: %[[A4:.*4]]: !llvm.struct<(array<1 x i64>)>,
-// CHECK-SAME: %[[A5:.*5]]: index,
-// CHECK-SAME: %[[A6:.*6]]: f64)
-// CHECK: %[[R:.*]]:5 = call @_insert_C_128_f64_0_0(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]])
-// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4
+// CHECK-SAME: %[[A1:.*0]]: memref,
+// CHECK-SAME: %[[A2:.*1]]: memref,
+// CHECK-SAME: %[[A3:.*2]]: memref,
+// CHECK-SAME: %[[A4:.*3]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>,
+// CHECK-SAME: %[[A5:.*4]]: index,
+// CHECK-SAME: %[[A6:.*5]]: f64)
+// CHECK: %[[R:.*]]:4 = call @_insert_C_128_f64_0_0(%[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]])
+// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3
 func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SV> {
   %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SV>
   %1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SV>
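The outlined helper names encode the storage scheme of the tensor being inserted into: `_insert_`, one letter per dimension level (`D` dense, `C` compressed), the static dimension sizes, the element type, and two trailing bit-width fields where `0` stands for the native `index` width (presumably the pointer and index widths of the encoding). Two illustrative encodings that would differ only in that suffix; the attribute values here are assumptions, not taken from this test file:

```mlir
// Default widths: insert helpers end in _0_0.
#SVDefault = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>

// Explicit widths: insert helpers end in _64_32 instead.
#SVNarrow = #sparse_tensor.encoding<{
  dimLevelType = [ "compressed" ],
  pointerBitWidth = 64,
  indexBitWidth = 32
}>
```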
@@ -551,26 +549,22 @@
 }
 // CHECK-LABEL: func.func private @_insert_C_128_f64_64_32(
-// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref,
-// CHECK-SAME: %[[A2:.*2]]: memref,
-// CHECK-SAME: %[[A3:.*3]]: memref,
-// CHECK-SAME: %[[A4:.*4]]: !llvm.struct<(array<1 x i64>)>,
-// CHECK-SAME: %[[A5:.*5]]: index,
-// CHECK-SAME: %[[A6:.*6]]: f64)
-// CHECK: %[[P:.*]] = sparse_tensor.push_back %[[A0]], %[[A3]], %[[A6]] {idx = 2 : index} : memref<3xindex>, memref, f64
-// CHECK: return %[[A0]], %[[A1]], %{{.*}}, %[[P]], %[[A4]]
-
+// CHECK-SAME: %[[A1:.*]]: memref,
+// CHECK-SAME: %[[A2:.*]]: memref,
+// CHECK-SAME: %[[A3:.*]]: memref,
+// CHECK-SAME: %[[A4:.*]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>,
+// CHECK-SAME: %[[A5:.*]]: index,
+// CHECK-SAME: %[[A6:.*]]: f64)
+//
 // CHECK-LABEL: func @sparse_insert_typed(
-// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref,
-// CHECK-SAME: %[[A2:.*2]]: memref,
-// CHECK-SAME: %[[A3:.*3]]: memref,
-// CHECK-SAME: %[[A4:.*4]]: !llvm.struct<(array<1 x i64>)>,
-// CHECK-SAME: %[[A5:.*5]]: index,
-// CHECK-SAME: %[[A6:.*6]]: f64)
-// CHECK: %[[R:.*]]:5 = call @_insert_C_128_f64_64_32(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]])
-// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4
+// CHECK-SAME: %[[A1:.*]]: memref,
+// CHECK-SAME: %[[A2:.*]]: memref,
+// CHECK-SAME: %[[A3:.*]]: memref,
+// CHECK-SAME: %[[A4:.*]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>,
+// CHECK-SAME: %[[A5:.*]]: index,
+// CHECK-SAME: %[[A6:.*]]: f64)
+// CHECK: %[[R:.*]]:4 = call @_insert_C_128_f64_64_32(%[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]])
+// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3
 func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SparseVector> {
   %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector>
   %1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SparseVector>
@@ -578,13 +572,12 @@
 }
 // CHECK-LABEL: func.func @sparse_nop_convert(
-// CHECK-SAME: %[[A0:.*]]: memref<3xindex>,
 // CHECK-SAME: %[[A1:.*]]: memref,
 // CHECK-SAME: %[[A2:.*]]: memref,
 // CHECK-SAME: %[[A3:.*]]: memref,
-// CHECK-SAME: %[[A4:.*]]: !llvm.struct<(array<1 x i64>)>)
-// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]] :
-// CHECK-SAME: memref<3xindex>, memref, memref, memref, !llvm.struct<(array<1 x i64>)>
+// CHECK-SAME: %[[A4:.*]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>)
+// CHECK: return %[[A1]], %[[A2]], %[[A3]], %[[A4]] :
+// CHECK-SAME: memref, memref, memref, !llvm.struct<(array<1 x i64>, array<3 x i64>)>
 func.func @sparse_nop_convert(%arg0: tensor<32xf32, #SparseVector>) -> tensor {
   %0 = sparse_tensor.convert %arg0 : tensor<32xf32, #SparseVector> to tensor
   return %0 : tensor
diff --git a/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir b/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir
--- a/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir
@@ -2,29 +2,35 @@
 #SV = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
-// CHECK-LABEL: func @sparse_alloc_sparse_vector(
-// CHECK-SAME: %[[A:.*]]: index) ->
-// CHECK-SAME: memref<3xindex>, memref, memref, memref, !llvm.struct<(array<1 x i64>)>
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK: %[[T1:.*]] = memref.alloc() : memref<3xindex>
-// CHECK: %[[T2:.*]] = memref.alloc() : memref<16xindex>
-// CHECK: %[[T3:.*]] = memref.cast %[[T2]] : memref<16xindex> to memref
-// CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T2]] : memref<16xindex>)
-// CHECK: %[[T4:.*]] = memref.alloc() : memref<16xindex>
-// CHECK: %[[T5:.*]] = memref.cast %[[T4]] : memref<16xindex> to memref
-// CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T4]] : memref<16xindex>)
-// CHECK: %[[T6:.*]] = memref.alloc() : memref<16xf64>
-// CHECK: %[[T7:.*]] = memref.cast %[[T6]] : memref<16xf64> to memref
-// CHECK: linalg.fill ins(%[[F0]] : f64) outs(%[[T6]] : memref<16xf64>)
-// CHECK: %[[T11:.*]] = llvm.mlir.undef : !llvm.struct<(array<1 x i64>)>
-// CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T1]] : memref<3xindex>)
-// CHECK: %[[T12:.*]] = arith.index_cast %[[A]] : index to i64
-// CHECK: %[[MD:.*]] = llvm.insertvalue %[[T12]], %[[T11]][0, 0]
-// CHECK: %[[P0:.*]] = sparse_tensor.push_back %[[T1]], %[[T3]]
-// CHECK: %[[P1:.*]] = sparse_tensor.push_back %[[T1]], %[[P0]]
-// CHECK: return %[[T1]], %[[P1]], %[[T5]], %[[T7]], %[[MD]]
+// CHECK-LABEL: func.func @sparse_alloc_sparse_vector(
+// CHECK-SAME: %[[VAL_0:.*]]: index) ->
+// CHECK-SAME: (memref, memref, memref, !llvm.struct<(array<1 x i64>, array<3 x i64>)>) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[I0:.*]] = arith.constant 0 : i64
+// CHECK: %[[F0:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK: %[[VAL_5:.*]] = memref.alloc() : memref<16xindex>
+// CHECK: %[[VAL_6:.*]] = memref.cast %[[VAL_5]] : memref<16xindex> to memref
+// CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[VAL_5]] : memref<16xindex>)
+// CHECK: %[[VAL_7:.*]] = memref.alloc() : memref<16xindex>
+// CHECK: %[[VAL_8:.*]] = memref.cast %[[VAL_7]] : memref<16xindex> to memref
+// CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[VAL_7]] : memref<16xindex>)
+// CHECK: %[[VAL_9:.*]] = memref.alloc() : memref<16xf64>
+// CHECK: %[[VAL_10:.*]] = memref.cast %[[VAL_9]] : memref<16xf64> to memref
+// CHECK: linalg.fill ins(%[[F0]] : f64) outs(%[[VAL_9]] : memref<16xf64>)
+// CHECK: %[[MD1:.*]] = llvm.mlir.undef : !llvm.struct<(array<1 x i64>, array<3 x i64>)>
+// CHECK: %[[MD2:.*]] = llvm.insertvalue %[[I0]], %[[MD1]][1, 0]
+// CHECK: %[[MD3:.*]] = llvm.insertvalue %[[I0]], %[[MD2]][1, 1]
+// CHECK: %[[MD4:.*]] = llvm.insertvalue %[[I0]], %[[MD3]][1, 2]
+// CHECK: %[[VAL_15:.*]] = arith.index_cast %[[VAL_0]] : index to i64
+// CHECK: %[[MD5:.*]] = llvm.insertvalue %[[VAL_15]], %[[MD4]][0, 0]
+// CHECK: %[[VAL_17:.*]], %[[VAL_18:.*]] = sparse_tensor.push_back %[[C0]], %[[VAL_6]], %[[C0]]
+// CHECK: %[[VAL_19:.*]] = arith.index_cast %[[VAL_18]] : index to i64
+// CHECK: %[[MD6:.*]] = llvm.insertvalue %[[VAL_19]], %[[MD5]][1, 0]
+// CHECK: %[[VAL_21:.*]], %[[VAL_22:.*]] = sparse_tensor.push_back %[[VAL_18]], %[[VAL_17]], %[[C0]], %[[C1]]
+// CHECK: %[[VAL_23:.*]] = arith.index_cast %[[VAL_22]] : index to i64
+// CHECK: %[[MD:.*]] = llvm.insertvalue %[[VAL_23]], %[[MD6]][1, 0]
+// CHECK: return %[[VAL_21]], %[[VAL_8]], %[[VAL_10]], %[[MD]]
 func.func @sparse_alloc_sparse_vector(%arg0: index) -> tensor {
   %0 = bufferization.alloc_tensor(%arg0) : tensor
   %1 = sparse_tensor.load %0 : tensor
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -124,19 +124,19 @@
 // -----
-func.func @sparse_push_back(%arg0: memref, %arg1: memref, %arg2: f32) -> memref {
+func.func @sparse_push_back(%arg0: index, %arg1: memref, %arg2: f32) -> (memref, index) {
   // expected-error@+1 {{'sparse_tensor.push_back' op failed to verify that value type matches element type of inBuffer}}
-  %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref, memref, f32
-  return %0 : memref
+  %0:2 = sparse_tensor.push_back %arg0, %arg1, %arg2 : index, memref, f32
+  return %0#0, %0#1 : memref, index
 }
 // -----
-func.func @sparse_push_back_n(%arg0: memref, %arg1: memref, %arg2: f32) -> memref {
+func.func @sparse_push_back_n(%arg0: index, %arg1: memref, %arg2: f32) -> (memref, index) {
   %c0 = arith.constant 0: index
   // expected-error@+1 {{'sparse_tensor.push_back' op n must be not less than 1}}
-  %0 = sparse_tensor.push_back %arg0, %arg1, %arg2, %c0 {idx = 2 : index} : memref, memref, f32, index
-  return %0 : memref
+  %0:2 = sparse_tensor.push_back %arg0, %arg1, %arg2, %c0 : index, memref, f32, index
+  return %0#0, %0#1 : memref, index
 }
 // -----
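For contrast with the two rejected cases above, a well-formed `push_back` after this change takes the current size as an `index` operand and returns the grown buffer together with the new size. A minimal sketch; the function name and the `f32` buffer element type are illustrative only:

```mlir
func.func @valid_push_back(%size: index, %buf: memref<?xf32>, %v: f32,
                           %n: index) -> (memref<?xf32>, index) {
  // n is a runtime value; the verifier only rejects the constant-0 case,
  // so %n >= 1 is the caller's obligation here.
  %out, %newSize = sparse_tensor.push_back %size, %buf, %v, %n
      : index, memref<?xf32>, f32, index
  return %out, %newSize : memref<?xf32>, index
}
```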
@@ -159,41 +159,41 @@
 // -----
 // CHECK-LABEL: func @sparse_push_back(
-// CHECK-SAME: %[[A:.*]]: memref,
+// CHECK-SAME: %[[A:.*]]: index,
 // CHECK-SAME: %[[B:.*]]: memref,
-// CHECK-SAME: %[[C:.*]]: f64) -> memref {
-// CHECK: %[[D:.*]] = sparse_tensor.push_back %[[A]], %[[B]], %[[C]] {idx = 2 : index} : memref, memref, f64
+// CHECK-SAME: %[[C:.*]]: f64) -> (memref, index) {
+// CHECK: %[[D:.*]] = sparse_tensor.push_back %[[A]], %[[B]], %[[C]] : index, memref, f64
 // CHECK: return %[[D]]
-func.func @sparse_push_back(%arg0: memref, %arg1: memref, %arg2: f64) -> memref {
-  %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref, memref, f64
-  return %0 : memref
+func.func @sparse_push_back(%arg0: index, %arg1: memref, %arg2: f64) -> (memref, index) {
+  %0:2 = sparse_tensor.push_back %arg0, %arg1, %arg2 : index, memref, f64
+  return %0#0, %0#1 : memref, index
 }
 // -----
 // CHECK-LABEL: func @sparse_push_back_inbound(
-// CHECK-SAME: %[[A:.*]]: memref,
+// CHECK-SAME: %[[A:.*]]: index,
 // CHECK-SAME: %[[B:.*]]: memref,
-// CHECK-SAME: %[[C:.*]]: f64) -> memref {
-// CHECK: %[[D:.*]] = sparse_tensor.push_back inbounds %[[A]], %[[B]], %[[C]] {idx = 2 : index} : memref, memref, f64
+// CHECK-SAME: %[[C:.*]]: f64) -> (memref, index) {
+// CHECK: %[[D:.*]] = sparse_tensor.push_back inbounds %[[A]], %[[B]], %[[C]] : index, memref, f64
 // CHECK: return %[[D]]
-func.func @sparse_push_back_inbound(%arg0: memref, %arg1: memref, %arg2: f64) -> memref {
-  %0 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 {idx = 2 : index} : memref, memref, f64
-  return %0 : memref
+func.func @sparse_push_back_inbound(%arg0: index, %arg1: memref, %arg2: f64) -> (memref, index) {
+  %0:2 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 : index, memref, f64
+  return %0#0, %0#1 : memref, index
 }
 // -----
 // CHECK-LABEL: func @sparse_push_back_n(
-// CHECK-SAME: %[[A:.*]]: memref,
+// CHECK-SAME: %[[A:.*]]: index,
 // CHECK-SAME: %[[B:.*]]: memref,
 // CHECK-SAME: %[[C:.*]]: f64,
-// CHECK-SAME: %[[D:.*]]: index) -> memref {
-// CHECK: %[[E:.*]] = sparse_tensor.push_back %[[A]], %[[B]], %[[C]], %[[D]] {idx = 2 : index} : memref, memref, f64, index
+// CHECK-SAME: %[[D:.*]]: index) -> (memref, index) {
+// CHECK: %[[E:.*]] = sparse_tensor.push_back %[[A]], %[[B]], %[[C]], %[[D]] : index, memref, f64, index
 // CHECK: return %[[E]]
-func.func @sparse_push_back_n(%arg0: memref, %arg1: memref, %arg2: f64, %arg3: index) -> memref {
-  %0 = sparse_tensor.push_back %arg0, %arg1, %arg2, %arg3 {idx = 2 : index} : memref, memref, f64, index
-  return %0 : memref
+func.func @sparse_push_back_n(%arg0: index, %arg1: memref, %arg2: f64, %arg3: index) -> (memref, index) {
+  %0:2 = sparse_tensor.push_back %arg0, %arg1, %arg2, %arg3 : index, memref, f64, index
+  return %0#0, %0#1 : memref, index
 }
 // -----
diff --git a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
--- a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
@@ -3,24 +3,21 @@
 #SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
 // CHECK-LABEL: func.func @for(
-// CHECK-SAME: %[[VAL_0:.*0]]: memref<3xindex>,
-// CHECK-SAME: %[[VAL_1:.*1]]: memref,
-// CHECK-SAME: %[[VAL_2:.*2]]: memref,
-// CHECK-SAME: %[[VAL_3:.*3]]: memref,
-// CHECK-SAME: %[[VAL_4:.*4]]: !llvm.struct<(array<1 x i64>)>,
-// CHECK-SAME: %[[VAL_5:.*5]]: index,
-// CHECK-SAME: %[[VAL_6:.*6]]: index,
-// CHECK-SAME: %[[VAL_7:.*7]]: index) -> (memref<3xindex>, memref, memref, memref, !llvm.struct<(array<1 x i64>)>) {
-// CHECK: %[[VAL_8:.*]]:5 = scf.for %[[VAL_9:.*]] = %[[VAL_5]] to %[[VAL_6]] step %[[VAL_7]] iter_args(
-// CHECK-SAME: %[[VAL_10:.*]] = %[[VAL_0]],
+// CHECK-SAME: %[[VAL_1:.*0]]: memref,
+// CHECK-SAME: %[[VAL_2:.*1]]: memref,
+// CHECK-SAME: %[[VAL_3:.*2]]: memref,
+// CHECK-SAME: %[[VAL_4:.*3]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>,
+// CHECK-SAME: %[[VAL_5:.*4]]: index,
+// CHECK-SAME: %[[VAL_6:.*5]]: index,
+// CHECK-SAME: %[[VAL_7:.*6]]: index) -> (memref, memref, memref, !llvm.struct<(array<1 x i64>, array<3 x i64>)>) {
+// CHECK: %[[VAL_8:.*]]:4 = scf.for %[[VAL_9:.*]] = %[[VAL_5]] to %[[VAL_6]] step %[[VAL_7]] iter_args(
 // CHECK-SAME: %[[VAL_11:.*]] = %[[VAL_1]],
 // CHECK-SAME: %[[VAL_12:.*]] = %[[VAL_2]],
 // CHECK-SAME: %[[VAL_13:.*]] = %[[VAL_3]],
 // CHECK-SAME: %[[VAL_14:.*]] = %[[VAL_4]])
-// CHECK: scf.yield %[[VAL_10]], %[[VAL_11]], %[[VAL_12]], %[[VAL_13]], %[[VAL_14]] :
+// CHECK: scf.yield %[[VAL_11]], %[[VAL_12]], %[[VAL_13]], %[[VAL_14]] :
 // CHECK: }
-// CHECK: return %[[VAL_8]]#0, %[[VAL_8]]#1, %[[VAL_8]]#2, %[[VAL_8]]#3, %[[VAL_8]]#4 :
-// CHECK: }
+// CHECK: return %[[VAL_8]]#0, %[[VAL_8]]#1, %[[VAL_8]]#2, %[[VAL_8]]#3
 func.func @for(%in: tensor<1024xf32, #SparseVector>,
                %lb: index, %ub: index, %step: index) -> tensor<1024xf32, #SparseVector> {
   %1 = scf.for %i = %lb to %ub step %step iter_args(%vin = %in)
@@ -31,24 +28,22 @@
 }
 // CHECK-LABEL: func.func @if(
-// CHECK-SAME: %[[VAL_0:.*0]]: memref<3xindex>,
-// CHECK-SAME: %[[VAL_1:.*1]]: memref,
-// CHECK-SAME: %[[VAL_2:.*2]]: memref,
-// CHECK-SAME: %[[VAL_3:.*3]]: memref,
-// CHECK-SAME: %[[VAL_4:.*4]]: !llvm.struct<(array<1 x i64>)>,
-// CHECK-SAME: %[[VAL_5:.*5]]: memref<3xindex>,
-// CHECK-SAME: %[[VAL_6:.*6]]: memref,
-// CHECK-SAME: %[[VAL_7:.*7]]: memref,
-// CHECK-SAME: %[[VAL_8:.*8]]: memref,
-// CHECK-SAME: %[[VAL_9:.*9]]: !llvm.struct<(array<1 x i64>)>,
+// CHECK-SAME: %[[VAL_1:.*0]]: memref,
+// CHECK-SAME: %[[VAL_2:.*1]]: memref,
+// CHECK-SAME: %[[VAL_3:.*2]]: memref,
+// CHECK-SAME: %[[VAL_4:.*3]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>,
+// CHECK-SAME: %[[VAL_6:.*4]]: memref,
+// CHECK-SAME: %[[VAL_7:.*5]]: memref,
+// CHECK-SAME: %[[VAL_8:.*6]]: memref,
+// CHECK-SAME: %[[VAL_9:.*7]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>,
 // CHECK-SAME: %[[VAL_10:.*]]: i1)
-// CHECK: %[[VAL_11:.*]]:5 = scf.if %[[VAL_10]]
-// CHECK: scf.yield %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_3]], %[[VAL_4]]
+// CHECK: %[[VAL_11:.*]]:4 = scf.if %[[VAL_10]]
+// CHECK: scf.yield %[[VAL_1]], %[[VAL_2]], %[[VAL_3]], %[[VAL_4]]
 // CHECK: } else {
-// CHECK: scf.yield %[[VAL_5]], %[[VAL_6]], %[[VAL_7]], %[[VAL_8]], %[[VAL_9]]
+// CHECK: scf.yield %[[VAL_6]], %[[VAL_7]], %[[VAL_8]], %[[VAL_9]]
 // CHECK: }
-// CHECK: return %[[VAL_11]]#0, %[[VAL_11]]#1, %[[VAL_11]]#2, %[[VAL_11]]#3, %[[VAL_11]]#4 :
-// CHECK-SAME: memref<3xindex>, memref, memref, memref, !llvm.struct<(array<1 x i64>)>
+// CHECK: return %[[VAL_11]]#0, %[[VAL_11]]#1, %[[VAL_11]]#2, %[[VAL_11]]#3 :
+// CHECK-SAME: memref, memref, memref, !llvm.struct<(array<1 x i64>, array<3 x i64>)>
 func.func @if(%t: tensor<1024xf32, #SparseVector>,
               %f: tensor<1024xf32, #SparseVector>,
               %c: i1) -> tensor<1024xf32, #SparseVector> {
@@ -62,29 +57,26 @@
 }
 // CHECK-LABEL: func.func @while(
-// CHECK-SAME: %[[VAL_0:.*0]]: memref<3xindex>,
-// CHECK-SAME: %[[VAL_1:.*1]]: memref,
-// CHECK-SAME: %[[VAL_2:.*2]]: memref,
-// CHECK-SAME: %[[VAL_3:.*3]]: memref,
-// CHECK-SAME: %[[VAL_4:.*4]]: !llvm.struct<(array<1 x i64>)>,
-// CHECK-SAME: %[[VAL_5:.*5]]: i1)
-// CHECK: %[[VAL_6:.*]]:5 = scf.while (
-// CHECK-SAME: %[[VAL_7:.*]] = %[[VAL_0]],
+// CHECK-SAME: %[[VAL_1:.*0]]: memref,
+// CHECK-SAME: %[[VAL_2:.*1]]: memref,
+// CHECK-SAME: %[[VAL_3:.*2]]: memref,
+// CHECK-SAME: %[[VAL_4:.*3]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>,
+// CHECK-SAME: %[[VAL_5:.*4]]: i1)
+// CHECK: %[[VAL_6:.*]]:4 = scf.while (
 // CHECK-SAME: %[[VAL_8:.*]] = %[[VAL_1]],
 // CHECK-SAME: %[[VAL_9:.*]] = %[[VAL_2]],
 // CHECK-SAME: %[[VAL_10:.*]] = %[[VAL_3]],
 // CHECK-SAME: %[[VAL_11:.*]] = %[[VAL_4]])
-// CHECK: scf.condition(%[[VAL_5]]) %[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_10]], %[[VAL_11]]
+// CHECK: scf.condition(%[[VAL_5]]) %[[VAL_8]], %[[VAL_9]], %[[VAL_10]], %[[VAL_11]]
 // CHECK: } do {
-// CHECK: ^bb0(%[[VAL_12:.*6]]: memref<3xindex>,
-// CHECK-SAME: %[[VAL_13:.*7]]: memref,
-// CHECK-SAME: %[[VAL_14:.*8]]: memref,
-// CHECK-SAME: %[[VAL_15:.*9]]: memref,
-// CHECK-SAME: %[[VAL_16:.*10]]: !llvm.struct<(array<1 x i64>)>):
-// CHECK: scf.yield %[[VAL_12]], %[[VAL_13]], %[[VAL_14]], %[[VAL_15]], %[[VAL_16]]
+// CHECK: ^bb0(%[[VAL_13:.*5]]: memref,
+// CHECK-SAME: %[[VAL_14:.*6]]: memref,
+// CHECK-SAME: %[[VAL_15:.*7]]: memref,
+// CHECK-SAME: %[[VAL_16:.*8]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>):
+// CHECK: scf.yield %[[VAL_13]], %[[VAL_14]], %[[VAL_15]], %[[VAL_16]]
 // CHECK: }
-// CHECK: return %[[VAL_6]]#0, %[[VAL_6]]#1, %[[VAL_6]]#2, %[[VAL_6]]#3, %[[VAL_6]]#4 :
-// CHECK-SAME: memref<3xindex>, memref, memref, memref, !llvm.struct<(array<1 x i64>)>
+// CHECK: return %[[VAL_6]]#0, %[[VAL_6]]#1, %[[VAL_6]]#2, %[[VAL_6]]#3 :
+// CHECK-SAME: memref, memref, memref, !llvm.struct<(array<1 x i64>, array<3 x i64>)>
 func.func @while(%arg0: tensor<1024xf32, #SparseVector>, %c: i1) -> tensor<1024xf32, #SparseVector> {
   %0 = scf.while (%in = %arg0) : (tensor<1024xf32, #SparseVector>) -> tensor<1024xf32, #SparseVector> {
     scf.condition(%c) %in : tensor<1024xf32, #SparseVector>
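These scf tests exercise the 1:N type conversion: a single sparse tensor SSA value is rewritten into its storage components (previously five values including the `memref<3xindex>` size array, now four, with sizes folded into the metadata struct), and `scf.for`, `scf.if`, and `scf.while` simply thread all components through. The source form being converted looks like the following sketch; the loop body is elided in the hunk above, so the `scf.yield %vin` is an assumption:

```mlir
#SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>

func.func @for(%in: tensor<1024xf32, #SparseVector>,
               %lb: index, %ub: index, %step: index)
    -> tensor<1024xf32, #SparseVector> {
  // One sparse-tensor iter_arg here becomes four iter_args after conversion.
  %1 = scf.for %i = %lb to %ub step %step iter_args(%vin = %in)
      -> tensor<1024xf32, #SparseVector> {
    scf.yield %vin : tensor<1024xf32, #SparseVector>
  }
  return %1 : tensor<1024xf32, #SparseVector>
}
```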
diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
@@ -13,136 +13,147 @@
 // Computes C = A x B with all matrices sparse (SpMSpM) in CSR.
 //
 // CHECK-LABEL: func.func private @_insert_D_C_4_4_f64_0_0(
-// CHECK-SAME: %[[VAL_0:.*0]]: memref<3xindex>,
+// CHECK-SAME: %[[VAL_0:.*0]]: memref,
 // CHECK-SAME: %[[VAL_1:.*1]]: memref,
-// CHECK-SAME: %[[VAL_2:.*2]]: memref,
-// CHECK-SAME: %[[VAL_3:.*3]]: memref,
-// CHECK-SAME: %[[VAL_4:.*4]]: !llvm.struct<(array<2 x i64>)>,
+// CHECK-SAME: %[[VAL_2:.*2]]: memref,
+// CHECK-SAME: %[[VAL_3:.*3]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>,
+// CHECK-SAME: %[[VAL_4:.*4]]: index,
 // CHECK-SAME: %[[VAL_5:.*5]]: index,
-// CHECK-SAME: %[[VAL_6:.*6]]: index,
-// CHECK-SAME: %[[VAL_7:.*7]]: f64) -> (memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)>) {
-// CHECK: %[[VAL_8:.*]] = arith.constant false
-// CHECK: %[[VAL_9:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_5]], %[[VAL_9]] : index
-// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_5]]] : memref
-// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_10]]] : memref
-// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_9]]] : memref<3xindex>
-// CHECK: %[[VAL_14:.*]] = arith.subi %[[VAL_12]], %[[VAL_9]] : index
-// CHECK: %[[VAL_15:.*]] = arith.cmpi ult, %[[VAL_11]], %[[VAL_12]] : index
+// CHECK-SAME: %[[VAL_6:.*6]]: f64)
+// CHECK: %[[VAL_7:.*]] = arith.constant false
+// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_4]], %[[VAL_8]] : index
+// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_4]]] : memref
+// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_9]]] : memref
+// CHECK: %[[VAL_12:.*]] = llvm.extractvalue %[[VAL_3]][1, 1] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_13:.*]] = arith.index_cast %[[VAL_12]] : i64 to index
+// CHECK: %[[VAL_14:.*]] = arith.subi %[[VAL_11]], %[[VAL_8]] : index
+// CHECK: %[[VAL_15:.*]] = arith.cmpi ult, %[[VAL_10]], %[[VAL_11]] : index
 // CHECK: %[[VAL_16:.*]] = scf.if %[[VAL_15]] -> (i1) {
-// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_14]]] : memref
-// CHECK: %[[VAL_18:.*]] = arith.cmpi eq, %[[VAL_17]], %[[VAL_6]] : index
+// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_14]]] : memref
+// CHECK: %[[VAL_18:.*]] = arith.cmpi eq, %[[VAL_17]], %[[VAL_5]] : index
 // CHECK: scf.yield %[[VAL_18]] : i1
 // CHECK: } else {
-// CHECK: memref.store %[[VAL_13]], %[[VAL_1]]{{\[}}%[[VAL_5]]] : memref
-// CHECK: scf.yield %[[VAL_8]] : i1
+// CHECK: memref.store %[[VAL_13]], %[[VAL_0]]{{\[}}%[[VAL_4]]] : memref
+// CHECK: scf.yield %[[VAL_7]] : i1
 // CHECK: }
-// CHECK: %[[VAL_19:.*]] = scf.if %[[VAL_20:.*]] -> (memref) {
-// CHECK: scf.yield %[[VAL_2]] : memref
+// CHECK: %[[VAL_19:.*]]:2 = scf.if %[[VAL_20:.*]] -> (memref, !llvm.struct<(array<2 x i64>, array<3 x i64>)>) {
+// CHECK: scf.yield %[[VAL_1]], %[[VAL_3]] : memref, !llvm.struct<(array<2 x i64>, array<3 x i64>)>
 // CHECK: } else {
-// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_13]], %[[VAL_9]] : index
-// CHECK: memref.store %[[VAL_21]], %[[VAL_1]]{{\[}}%[[VAL_10]]] : memref
-// CHECK: %[[VAL_22:.*]] = sparse_tensor.push_back %[[VAL_0]], %[[VAL_2]], %[[VAL_6]] {idx = 1 : index} : memref<3xindex>, memref, index
-// CHECK: scf.yield %[[VAL_22]] : memref
+// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_13]], %[[VAL_8]] : index
+// CHECK: memref.store %[[VAL_21]], %[[VAL_0]]{{\[}}%[[VAL_9]]] : memref
+// CHECK: %[[VAL_22:.*]], %[[VAL_23:.*]] = sparse_tensor.push_back %[[VAL_13]], %[[VAL_1]], %[[VAL_5]] : index, memref, index
+// CHECK: %[[VAL_24:.*]] = arith.index_cast %[[VAL_23]] : index to i64
+// CHECK: %[[VAL_25:.*]] = llvm.insertvalue %[[VAL_24]], %[[VAL_3]][1, 1] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
+// CHECK: scf.yield %[[VAL_22]], %[[VAL_25]] : memref, !llvm.struct<(array<2 x i64>, array<3 x i64>)>
 // CHECK: }
-// CHECK: %[[VAL_23:.*]] = sparse_tensor.push_back %[[VAL_0]], %[[VAL_3]], %[[VAL_7]] {idx = 2 : index} : memref<3xindex>, memref, f64
-// CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_24:.*]], %[[VAL_23]], %[[VAL_4]] : memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)>
+// CHECK: %[[VAL_26:.*]] = llvm.extractvalue %[[VAL_27:.*]]#1[1, 2] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_28:.*]] = arith.index_cast %[[VAL_26]] : i64 to index
+// CHECK: %[[VAL_29:.*]], %[[VAL_30:.*]] = sparse_tensor.push_back %[[VAL_28]], %[[VAL_2]], %[[VAL_6]] : index, memref, f64
+// CHECK: %[[VAL_31:.*]] = arith.index_cast %[[VAL_30]] : index to i64
+// CHECK: %[[VAL_32:.*]] = llvm.insertvalue %[[VAL_31]], %[[VAL_27]]#1[1, 2] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
+// CHECK: return %[[VAL_0]], %[[VAL_27]]#0, %[[VAL_29]], %[[VAL_32]] : memref, memref, memref, !llvm.struct<(array<2 x i64>, array<3 x i64>)>
 // CHECK: }
 // CHECK-LABEL: func.func @matmul(
-// CHECK-SAME: %[[VAL_0:.*0]]: memref<3xindex>,
+// CHECK-SAME: %[[VAL_0:.*0]]: memref,
 // CHECK-SAME: %[[VAL_1:.*1]]: memref,
-// CHECK-SAME: %[[VAL_2:.*2]]: memref,
-// CHECK-SAME: %[[VAL_3:.*3]]: memref,
-// CHECK-SAME: %[[VAL_4:.*4]]: !llvm.struct<(array<2 x i64>)>,
-// CHECK-SAME: %[[VAL_5:.*5]]: memref<3xindex>,
-// CHECK-SAME: %[[VAL_6:.*6]]: memref,
-// CHECK-SAME: %[[VAL_7:.*7]]: memref,
-// CHECK-SAME: %[[VAL_8:.*8]]: memref,
-// CHECK-SAME: %[[VAL_9:.*9]]: !llvm.struct<(array<2 x i64>)>) -> (memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)>) {
-// CHECK: %[[VAL_10:.*]] = arith.constant 4 : index
-// CHECK: %[[VAL_11:.*]] = arith.constant 4 : i64
-// CHECK: %[[VAL_12:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK: %[[VAL_13:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_14:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_15:.*]] = arith.constant false
-// CHECK: %[[VAL_16:.*]] = arith.constant true
-// CHECK: %[[VAL_17:.*]] = memref.alloc() : memref<3xindex>
+// CHECK-SAME: %[[VAL_2:.*2]]: memref,
+// CHECK-SAME: %[[VAL_3:.*3]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>,
+// CHECK-SAME: %[[VAL_4:.*4]]: memref,
+// CHECK-SAME: %[[VAL_5:.*5]]: memref,
+// CHECK-SAME: %[[VAL_6:.*6]]: memref,
+// CHECK-SAME: %[[VAL_7:.*7]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>)
+// CHECK: %[[VAL_8:.*]] = arith.constant 4 : index
+// CHECK: %[[VAL_9:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_10:.*]] = arith.constant 4 : i64
+// CHECK: %[[VAL_11:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK: %[[VAL_12:.*]] = arith.constant 0 : i64
+// CHECK: %[[VAL_13:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_14:.*]] = arith.constant false
+// CHECK: %[[VAL_15:.*]] = arith.constant true
+// CHECK: %[[VAL_16:.*]] = memref.alloc() : memref<16xindex>
+// CHECK: %[[VAL_17:.*]] = memref.cast %[[VAL_16]] : memref<16xindex> to memref
 // CHECK: %[[VAL_18:.*]] = memref.alloc() : memref<16xindex>
 // CHECK: %[[VAL_19:.*]] = memref.cast %[[VAL_18]] : memref<16xindex> to memref
-// CHECK: %[[VAL_20:.*]] = memref.alloc() : memref<16xindex>
-// CHECK: %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<16xindex> to memref
-// CHECK: %[[VAL_22:.*]] = memref.alloc() : memref<16xf64>
-// CHECK: %[[VAL_23:.*]] = memref.cast %[[VAL_22]] : memref<16xf64> to memref
-// CHECK: %[[VAL_24:.*]] = llvm.mlir.undef : !llvm.struct<(array<2 x i64>)>
-// CHECK: linalg.fill ins(%[[VAL_13]] : index) outs(%[[VAL_17]] : memref<3xindex>)
-// CHECK: %[[VAL_25:.*]] = llvm.insertvalue %[[VAL_11]], %[[VAL_24]][0, 0] : !llvm.struct<(array<2 x i64>)>
-// CHECK: %[[VAL_26:.*]] = llvm.insertvalue %[[VAL_11]], %[[VAL_25]][0, 1] : !llvm.struct<(array<2 x i64>)>
-// CHECK: %[[VAL_27:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_19]], %[[VAL_13]] {idx = 0 : index} : memref<3xindex>, memref, index
-// CHECK: %[[VAL_28:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_27]], %[[VAL_13]], %[[VAL_10]] {idx = 0 : index} : memref<3xindex>, memref, index, index
-// CHECK: %[[VAL_29:.*]] = memref.alloc() : memref<4xf64>
-// CHECK: %[[VAL_30:.*]] = memref.alloc() : memref<4xi1>
-// CHECK: %[[VAL_31:.*]] = memref.alloc() : memref<4xindex>
-// CHECK: %[[VAL_32:.*]] = memref.cast %[[VAL_31]] : memref<4xindex> to memref
-// CHECK: linalg.fill ins(%[[VAL_12]] : f64) outs(%[[VAL_29]] : memref<4xf64>)
-// CHECK: linalg.fill ins(%[[VAL_15]] : i1) outs(%[[VAL_30]] : memref<4xi1>)
-// CHECK: %[[VAL_33:.*]]:5 = scf.for %[[VAL_34:.*]] = %[[VAL_13]] to %[[VAL_10]] step %[[VAL_14]] iter_args(%[[VAL_35:.*]] = %[[VAL_17]], %[[VAL_36:.*]] = %[[VAL_28]], %[[VAL_37:.*]] = %[[VAL_21]], %[[VAL_38:.*]] = %[[VAL_23]], %[[VAL_39:.*]] = %[[VAL_26]]) -> (memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)>) {
-// CHECK: %[[VAL_40:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_34]]] : memref
-// CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_34]], %[[VAL_14]] : index
-// CHECK: %[[VAL_42:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_41]]] : memref
-// CHECK: %[[VAL_43:.*]] = scf.for %[[VAL_44:.*]] = %[[VAL_40]] to %[[VAL_42]] step %[[VAL_14]] iter_args(%[[VAL_45:.*]] = %[[VAL_13]]) -> (index) {
-// CHECK: %[[VAL_46:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_44]]] : memref
-// CHECK: %[[VAL_47:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_44]]] : memref
-// CHECK: %[[VAL_48:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_46]]] : memref
-// CHECK: %[[VAL_49:.*]] = arith.addi %[[VAL_46]], %[[VAL_14]] : index
-// CHECK: %[[VAL_50:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_49]]] : memref
-// CHECK: %[[VAL_51:.*]] = scf.for %[[VAL_52:.*]] = %[[VAL_48]] to %[[VAL_50]] step %[[VAL_14]] iter_args(%[[VAL_53:.*]] = %[[VAL_45]]) -> (index) {
-// CHECK: %[[VAL_54:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_52]]] : memref
-// CHECK: %[[VAL_55:.*]] = memref.load %[[VAL_29]]{{\[}}%[[VAL_54]]] : memref<4xf64>
-// CHECK: %[[VAL_56:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_52]]] : memref
-// CHECK: %[[VAL_57:.*]] = arith.mulf %[[VAL_47]], %[[VAL_56]] : f64
-// CHECK: %[[VAL_58:.*]] = arith.addf %[[VAL_55]], %[[VAL_57]] : f64
-// CHECK: %[[VAL_59:.*]] = memref.load %[[VAL_30]]{{\[}}%[[VAL_54]]] : memref<4xi1>
-// CHECK: %[[VAL_60:.*]] = arith.cmpi eq, %[[VAL_59]], %[[VAL_15]] : i1
-// CHECK: %[[VAL_61:.*]] = scf.if %[[VAL_60]] -> (index) {
-// CHECK: memref.store %[[VAL_16]], %[[VAL_30]]{{\[}}%[[VAL_54]]] : memref<4xi1>
-// CHECK: memref.store %[[VAL_54]], %[[VAL_31]]{{\[}}%[[VAL_53]]] : memref<4xindex>
-// CHECK: %[[VAL_62:.*]] = arith.addi %[[VAL_53]], %[[VAL_14]] : index
-// CHECK: scf.yield %[[VAL_62]] : index
+// CHECK: %[[VAL_20:.*]] = memref.alloc() : memref<16xf64>
+// CHECK: %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<16xf64> to memref
+// CHECK: %[[VAL_22:.*]] = llvm.mlir.undef : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_23:.*]] = llvm.insertvalue %[[VAL_12]], %[[VAL_22]][1, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_24:.*]] = llvm.insertvalue %[[VAL_12]], %[[VAL_23]][1, 1] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_25:.*]] = llvm.insertvalue %[[VAL_12]], %[[VAL_24]][1, 2] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_26:.*]] = llvm.insertvalue %[[VAL_10]], %[[VAL_25]][0, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_27:.*]] = llvm.insertvalue %[[VAL_10]], %[[VAL_26]][0, 1] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_28:.*]], %[[VAL_29:.*]] = sparse_tensor.push_back %[[VAL_9]], %[[VAL_17]], %[[VAL_9]] : index, memref, index
+// CHECK: %[[VAL_30:.*]] = arith.index_cast %[[VAL_29]] : index to i64
+// CHECK: %[[VAL_31:.*]] = llvm.insertvalue %[[VAL_30]], %[[VAL_27]][1, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_32:.*]], %[[VAL_33:.*]] = sparse_tensor.push_back %[[VAL_29]], %[[VAL_28]], %[[VAL_9]], %[[VAL_8]] : index, memref, index, index
+// CHECK: %[[VAL_34:.*]] = arith.index_cast %[[VAL_33]] : index to i64
+// CHECK: %[[VAL_35:.*]] = llvm.insertvalue %[[VAL_34]], %[[VAL_31]][1, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_36:.*]] = memref.alloc() : memref<4xf64>
+// CHECK: %[[VAL_37:.*]] = memref.alloc() : memref<4xi1>
+// CHECK: %[[VAL_38:.*]] = memref.alloc() : memref<4xindex>
+// CHECK: %[[VAL_39:.*]] = memref.cast %[[VAL_38]] : memref<4xindex> to memref
+// CHECK: linalg.fill ins(%[[VAL_11]] : f64) outs(%[[VAL_36]] : memref<4xf64>)
+// CHECK: linalg.fill ins(%[[VAL_14]] : i1) outs(%[[VAL_37]] : memref<4xi1>)
+// CHECK: %[[VAL_40:.*]]:4 = scf.for %[[VAL_41:.*]] = %[[VAL_9]] to %[[VAL_8]] step %[[VAL_13]] iter_args(%[[VAL_42:.*]] = %[[VAL_32]], %[[VAL_43:.*]] = %[[VAL_19]], %[[VAL_44:.*]] = %[[VAL_21]], %[[VAL_45:.*]] = %[[VAL_35]]) -> (memref, memref, memref, !llvm.struct<(array<2 x i64>, array<3 x i64>)>) {
+// CHECK: %[[VAL_46:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_41]]] : memref
+// CHECK: %[[VAL_47:.*]] = arith.addi %[[VAL_41]], %[[VAL_13]] : index
+// CHECK: %[[VAL_48:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_47]]] : memref
+// CHECK: %[[VAL_49:.*]] = scf.for %[[VAL_50:.*]] = %[[VAL_46]] to %[[VAL_48]] step %[[VAL_13]] iter_args(%[[VAL_51:.*]] = %[[VAL_9]]) -> (index) {
+// CHECK: %[[VAL_52:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_50]]] : memref
+// CHECK: %[[VAL_53:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_50]]] : memref
+// CHECK: %[[VAL_54:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_52]]] : memref
+// CHECK: %[[VAL_55:.*]] = arith.addi %[[VAL_52]], %[[VAL_13]] : index
+// CHECK: %[[VAL_56:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_55]]] : memref
+// CHECK: %[[VAL_57:.*]] = scf.for %[[VAL_58:.*]] = %[[VAL_54]] to %[[VAL_56]] step %[[VAL_13]] iter_args(%[[VAL_59:.*]] = %[[VAL_51]]) -> (index) {
+// CHECK: %[[VAL_60:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_58]]] : memref
+// CHECK: %[[VAL_61:.*]] = memref.load %[[VAL_36]]{{\[}}%[[VAL_60]]] : memref<4xf64>
+// CHECK: %[[VAL_62:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_58]]] : memref
+// CHECK: %[[VAL_63:.*]] = arith.mulf %[[VAL_53]], %[[VAL_62]] : f64
+// CHECK: %[[VAL_64:.*]] = arith.addf %[[VAL_61]], %[[VAL_63]] : f64
+// CHECK: %[[VAL_65:.*]] = memref.load %[[VAL_37]]{{\[}}%[[VAL_60]]] : memref<4xi1>
+// CHECK: %[[VAL_66:.*]] = arith.cmpi eq, %[[VAL_65]], %[[VAL_14]] : i1
+// CHECK: %[[VAL_67:.*]] = scf.if %[[VAL_66]] -> (index) {
+// CHECK: memref.store %[[VAL_15]], %[[VAL_37]]{{\[}}%[[VAL_60]]] : memref<4xi1>
+// CHECK: memref.store %[[VAL_60]], %[[VAL_38]]{{\[}}%[[VAL_59]]] : memref<4xindex>
+// CHECK: %[[VAL_68:.*]] = arith.addi %[[VAL_59]], %[[VAL_13]] : index
+// CHECK: scf.yield %[[VAL_68]] : index
 // CHECK: } else {
-// CHECK: scf.yield %[[VAL_53]] : index
+// CHECK: scf.yield %[[VAL_59]] : index
 // CHECK: }
-// CHECK: memref.store %[[VAL_58]], %[[VAL_29]]{{\[}}%[[VAL_54]]] : memref<4xf64>
-// CHECK: scf.yield %[[VAL_63:.*]] : index
+// CHECK: memref.store %[[VAL_64]], %[[VAL_36]]{{\[}}%[[VAL_60]]] : memref<4xf64>
+// CHECK: scf.yield %[[VAL_69:.*]] : index
 // CHECK: }
-// CHECK: scf.yield %[[VAL_64:.*]] : index
+// CHECK: scf.yield %[[VAL_70:.*]] : index
 // CHECK: }
-// CHECK: sparse_tensor.sort %[[VAL_65:.*]], %[[VAL_32]] : memref
-// CHECK: %[[VAL_66:.*]]:5 = scf.for %[[VAL_67:.*]] = %[[VAL_13]] to %[[VAL_65]] step %[[VAL_14]] iter_args(%[[VAL_68:.*]] = %[[VAL_35]], %[[VAL_69:.*]] = %[[VAL_36]], %[[VAL_70:.*]] = %[[VAL_37]], %[[VAL_71:.*]] = %[[VAL_38]], %[[VAL_72:.*]] = %[[VAL_39]]) -> (memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)>) {
-// CHECK: %[[VAL_73:.*]] = memref.load %[[VAL_31]]{{\[}}%[[VAL_67]]] : memref<4xindex>
-// CHECK: %[[VAL_74:.*]] = memref.load %[[VAL_29]]{{\[}}%[[VAL_73]]] : memref<4xf64>
-// CHECK: %[[VAL_75:.*]]:5 = func.call @_insert_D_C_4_4_f64_0_0(%[[VAL_68]], %[[VAL_69]], %[[VAL_70]], %[[VAL_71]], %[[VAL_72]], %[[VAL_34]], %[[VAL_73]], %[[VAL_74]]) : (memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)>, index, index, f64) -> (memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)>)
-// CHECK: memref.store %[[VAL_12]], %[[VAL_29]]{{\[}}%[[VAL_73]]] : memref<4xf64>
-// CHECK: memref.store %[[VAL_15]], %[[VAL_30]]{{\[}}%[[VAL_73]]] : memref<4xi1>
-// CHECK: scf.yield %[[VAL_75]]#0, %[[VAL_75]]#1, %[[VAL_75]]#2, %[[VAL_75]]#3, %[[VAL_75]]#4 : memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)>
+// CHECK: sparse_tensor.sort %[[VAL_71:.*]], %[[VAL_39]] : memref
+// CHECK: %[[VAL_72:.*]]:4 = scf.for %[[VAL_73:.*]] = %[[VAL_9]] to %[[VAL_71]] step %[[VAL_13]] iter_args(%[[VAL_74:.*]] = %[[VAL_42]], %[[VAL_75:.*]] = %[[VAL_43]], %[[VAL_76:.*]] = %[[VAL_44]], %[[VAL_77:.*]] = %[[VAL_45]]) -> (memref, memref, memref, !llvm.struct<(array<2 x i64>, array<3 x i64>)>) {
+// CHECK: %[[VAL_78:.*]] = memref.load %[[VAL_38]]{{\[}}%[[VAL_73]]] : memref<4xindex>
+// CHECK: %[[VAL_79:.*]] = memref.load %[[VAL_36]]{{\[}}%[[VAL_78]]] : memref<4xf64>
+// CHECK: %[[VAL_80:.*]]:4 = func.call @_insert_D_C_4_4_f64_0_0(%[[VAL_74]], %[[VAL_75]], %[[VAL_76]], %[[VAL_77]], %[[VAL_41]], %[[VAL_78]], %[[VAL_79]]) : (memref, memref, memref, !llvm.struct<(array<2 x i64>, array<3 x i64>)>, index, index, f64) -> (memref, memref, memref, !llvm.struct<(array<2 x i64>, array<3 x i64>)>)
+// CHECK: memref.store %[[VAL_11]], %[[VAL_36]]{{\[}}%[[VAL_78]]] : memref<4xf64>
+// CHECK: memref.store %[[VAL_14]], %[[VAL_37]]{{\[}}%[[VAL_78]]] : memref<4xi1>
+// CHECK: scf.yield %[[VAL_80]]#0, %[[VAL_80]]#1, %[[VAL_80]]#2, %[[VAL_80]]#3 : memref, memref, memref, !llvm.struct<(array<2 x i64>, array<3 x i64>)>
 // CHECK: }
-// CHECK: scf.yield %[[VAL_76:.*]]#0, %[[VAL_76]]#1, %[[VAL_76]]#2, %[[VAL_76]]#3, %[[VAL_76]]#4 : memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)>
+// CHECK: scf.yield %[[VAL_81:.*]]#0, %[[VAL_81]]#1, %[[VAL_81]]#2, %[[VAL_81]]#3 : memref, memref, memref, !llvm.struct<(array<2 x i64>, array<3 x i64>)>
 // CHECK: }
-// CHECK: memref.dealloc %[[VAL_29]] : memref<4xf64>
-// CHECK: memref.dealloc %[[VAL_30]] : memref<4xi1>
-// CHECK: memref.dealloc %[[VAL_31]] : memref<4xindex>
-// CHECK: %[[VAL_77:.*]] = memref.load %[[VAL_78:.*]]#0{{\[}}%[[VAL_13]]] : memref<3xindex>
-// CHECK: %[[VAL_79:.*]] = memref.load %[[VAL_78]]#1{{\[}}%[[VAL_13]]] : memref
-// CHECK: %[[VAL_80:.*]] = scf.for %[[VAL_81:.*]] = %[[VAL_14]] to %[[VAL_77]] step %[[VAL_14]] iter_args(%[[VAL_82:.*]] = %[[VAL_79]]) -> (index) {
-// CHECK: %[[VAL_83:.*]] = memref.load %[[VAL_78]]#1{{\[}}%[[VAL_81]]] : memref
-// CHECK: %[[VAL_84:.*]] = arith.cmpi eq, %[[VAL_83]], %[[VAL_13]] : index
-// CHECK: %[[VAL_85:.*]] = arith.select %[[VAL_84]], %[[VAL_82]], %[[VAL_83]] : index
-// CHECK: scf.if %[[VAL_84]] {
-// CHECK: memref.store %[[VAL_82]], %[[VAL_78]]#1{{\[}}%[[VAL_81]]] : memref
+// CHECK: memref.dealloc %[[VAL_36]] : memref<4xf64>
+// CHECK: memref.dealloc %[[VAL_37]] : memref<4xi1>
+// CHECK: memref.dealloc %[[VAL_38]] : memref<4xindex>
+// CHECK: %[[VAL_82:.*]] = llvm.extractvalue %[[VAL_83:.*]]#3[1, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_84:.*]] = arith.index_cast %[[VAL_82]] : i64 to index
+// CHECK: %[[VAL_85:.*]] = memref.load %[[VAL_83]]#0{{\[}}%[[VAL_9]]] : memref
+// CHECK: %[[VAL_86:.*]] = scf.for %[[VAL_87:.*]] = %[[VAL_13]] to %[[VAL_84]] step %[[VAL_13]] iter_args(%[[VAL_88:.*]] = %[[VAL_85]]) -> (index) {
+// CHECK: %[[VAL_89:.*]] = memref.load %[[VAL_83]]#0{{\[}}%[[VAL_87]]] : memref
+// CHECK: %[[VAL_90:.*]] = arith.cmpi eq, %[[VAL_89]], %[[VAL_9]] : index
+// CHECK: %[[VAL_91:.*]] = arith.select %[[VAL_90]], %[[VAL_88]], %[[VAL_89]] : index
+// CHECK: scf.if %[[VAL_90]] {
+// CHECK: memref.store %[[VAL_88]], %[[VAL_83]]#0{{\[}}%[[VAL_87]]] : memref
 // CHECK: }
-// CHECK: scf.yield %[[VAL_85]] : index
+// CHECK: scf.yield %[[VAL_91]] : index
 // CHECK: }
-// CHECK: return %[[VAL_78]]#0, %[[VAL_78]]#1, %[[VAL_78]]#2, %[[VAL_78]]#3, %[[VAL_78]]#4 : memref<3xindex>, memref, memref, memref, !llvm.struct<(array<2 x i64>)>
+// CHECK: return %[[VAL_83]]#0, %[[VAL_83]]#1, %[[VAL_83]]#2, %[[VAL_83]]#3 : memref, memref, memref, !llvm.struct<(array<2 x i64>, array<3 x i64>)>
 // CHECK: }
 func.func @matmul(%A: tensor<4x8xf64, #CSR>,
                   %B: tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR> {
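The matmul CHECKs above show the new bookkeeping in full: every buffer size that used to live in the `memref<3xindex>` array is now an `i64` field in the second array of the metadata struct, read before a `push_back` and written back after it. The idiom, reduced to a single values-buffer update; the names and the `f64` element type are illustrative:

```mlir
%old64 = llvm.extractvalue %md[1, 2] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
%old = arith.index_cast %old64 : i64 to index
%buf2, %new = sparse_tensor.push_back %old, %buf, %v : index, memref<?xf64>, f64
%new64 = arith.index_cast %new : index to i64
%md2 = llvm.insertvalue %new64, %md[1, 2] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
```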
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir
@@ -17,16 +17,15 @@
     %buffer = memref.alloc(%c1) : memref
     memref.store %c0, %bufferSizes[%c0] : memref
-    %buffer2 = sparse_tensor.push_back %bufferSizes, %buffer, %d2 {idx=0 : index} : memref, memref, f32
-    %buffer3 = sparse_tensor.push_back %bufferSizes, %buffer2, %d1, %c10 {idx=0 : index} : memref, memref, f32, index
+    %buffer2, %s0 = sparse_tensor.push_back %c0, %buffer, %d2 : index, memref, f32
+    %buffer3, %s1 = sparse_tensor.push_back %s0, %buffer2, %d1, %c10 : index, memref, f32, index
     // CHECK: 16
     %capacity = memref.dim %buffer3, %c0 : memref
     vector.print %capacity : index
-    // CHECK: ( 11 )
-    %size = vector.transfer_read %bufferSizes[%c0], %c0: memref, vector<1xindex>
-    vector.print %size : vector<1xindex>
+    // CHECK: 11
+    vector.print %s1 : index
     // CHECK: ( 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
     %values = vector.transfer_read %buffer3[%c0], %d0: memref, vector<11xf32>
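The two printed values differ for a reason worth spelling out: the buffer is allocated with capacity 1, the first `push_back` fills it, and the second pushes 10 more values for a final size of 11. Since a full buffer is reallocated before pushing (presumably by repeated doubling, which is consistent with the printed capacity), the capacity grows 1 -> 2 -> 4 -> 8 -> 16 while the returned size stays exact. A distilled view of the checks; the `memref<?xf32>` element type is an assumption, since this hunk does not show it:

```mlir
%capacity = memref.dim %buffer3, %c0 : memref<?xf32>  // prints 16
vector.print %capacity : index
vector.print %s1 : index                              // prints 11
```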