diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -365,7 +365,7 @@ Arguments<(ins SparseTensorStorageSpecifier:$specifier, SparseTensorStorageSpecifierKindAttr:$specifierKind, OptionalAttr:$dim)>, - Results<(outs AnyType:$result)> { + Results<(outs Index:$result)> { let summary = ""; let description = [{ Returns the requested field of the given storage_specifier. @@ -374,12 +374,12 @@ ```mlir %0 = sparse_tensor.storage_specifier.get %arg0 idx_mem_sz at 0 - : !sparse_tensor.storage_specifier<#COO> to i64 + : !sparse_tensor.storage_specifier<#COO> ``` }]; let assemblyFormat = "$specifier $specifierKind (`at` $dim^)? attr-dict `:` " - "qualified(type($specifier)) `to` type($result)"; + "qualified(type($specifier))"; let hasVerifier = 1; let hasFolder = 1; } @@ -389,7 +389,7 @@ Arguments<(ins SparseTensorStorageSpecifier:$specifier, SparseTensorStorageSpecifierKindAttr:$specifierKind, OptionalAttr:$dim, - AnyType:$value)>, + Index:$value)>, Results<(outs SparseTensorStorageSpecifier:$result)> { let summary = ""; let description = [{ @@ -400,12 +400,12 @@ ```mlir %0 = sparse_tensor.storage_specifier.set %arg0 idx_mem_sz at 0 with %new_sz - : i32, !sparse_tensor.storage_specifier<#COO> + : !sparse_tensor.storage_specifier<#COO> ``` }]; let assemblyFormat = "$specifier $specifierKind (`at` $dim^)? 
`with` $value attr-dict `:` " - "type($value) `,` qualified(type($result))"; + "qualified(type($result))"; let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td @@ -65,13 +65,6 @@ }]> ]; - let extraClassDeclaration = [{ - // Get the integer type used to store memory and dimension sizes. - IntegerType getSizesType() const; - Type getFieldType(StorageSpecifierKind kind, std::optional dim) const; - Type getFieldType(StorageSpecifierKind kind, std::optional dim) const; - }]; - // We skipped the default builder that simply takes the input sparse tensor encoding // attribute since we need to normalize the dimension level type and remove unrelated // fields that are irrelavant to sparse tensor storage scheme. diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -571,7 +571,11 @@ enc.getContext(), dlts, AffineMap(), // dimOrdering (irrelavant to storage speicifer) AffineMap(), // highLvlOrdering (irrelavant to storage specifer) - enc.getPointerBitWidth(), enc.getIndexBitWidth(), + // Always use index for memSize, dimSize instead of reusing + // getBitwidth from pointers/indices. 
+ // It allows us to reuse the same SSA value for different bitwidth, + // It also avoids casting between index/integer (returned by DimOp) + 0, 0, // FIXME: we should keep the slice information, for now it is okay as only // constant can be used for slice ArrayRef{} /*enc.getDimSlices()*/); @@ -582,36 +586,6 @@ return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding)); } -IntegerType StorageSpecifierType::getSizesType() const { - unsigned idxBitWidth = - getEncoding().getIndexBitWidth() ? getEncoding().getIndexBitWidth() : 64u; - unsigned ptrBitWidth = - getEncoding().getIndexBitWidth() ? getEncoding().getIndexBitWidth() : 64u; - - return IntegerType::get(getContext(), std::max(idxBitWidth, ptrBitWidth)); -} - -// FIXME: see note [CLARIFY_DIM_LVL] in -// "lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h" -Type StorageSpecifierType::getFieldType(StorageSpecifierKind kind, - std::optional dim) const { - if (kind != StorageSpecifierKind::ValMemSize) - assert(dim); - - // Right now, we store every sizes metadata using the same size type. - // TODO: the field size type can be defined dimensional wise after sparse - // tensor encoding supports per dimension index/pointer bitwidth. - return getSizesType(); -} - -// FIXME: see note [CLARIFY_DIM_LVL] in -// "lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h" -Type StorageSpecifierType::getFieldType(StorageSpecifierKind kind, - std::optional dim) const { - return getFieldType(kind, dim ? std::optional(dim.value().getZExtValue()) - : std::nullopt); -} - //===----------------------------------------------------------------------===// // SparseTensorDialect Operations. 
//===----------------------------------------------------------------------===// @@ -776,12 +750,6 @@ LogicalResult GetStorageSpecifierOp::verify() { RETURN_FAILURE_IF_FAILED(verifySparsifierGetterSetter( getSpecifierKind(), getDim(), getSpecifier(), getOperation())) - // Checks the result type - if (getSpecifier().getType().getFieldType(getSpecifierKind(), getDim()) != - getResult().getType()) { - return emitError( - "type mismatch between requested specifier field and result value"); - } return success(); } @@ -802,12 +770,6 @@ LogicalResult SetStorageSpecifierOp::verify() { RETURN_FAILURE_IF_FAILED(verifySparsifierGetterSetter( getSpecifierKind(), getDim(), getSpecifier(), getOperation())) - // Checks the input type - if (getSpecifier().getType().getFieldType(getSpecifierKind(), getDim()) != - getValue().getType()) { - return emitError( - "type mismatch between requested specifier field and input value"); - } return success(); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h @@ -78,6 +78,9 @@ // Misc code generators and utilities. //===----------------------------------------------------------------------===// +/// Add type casting between arith and index types when needed. +Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy); + /// Generates a 1-valued attribute of the given type. This supports /// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`, /// for unsupported types we raise `llvm_unreachable` rather than diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp @@ -204,6 +204,39 @@ // Misc code generators. 
//===----------------------------------------------------------------------===//

+/// Casts `value` to `dstTy`, returning `value` itself when the types match.
+/// Covers int <=> index, float ext/trunc and integer ext/trunc.
+Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
+                             Type dstTy) {
+  Type srcTy = value.getType();
+  if (srcTy != dstTy) {
+    // int <=> index
+    if (dstTy.isa<IndexType>() || srcTy.isa<IndexType>())
+      return builder.create<arith::IndexCastOp>(loc, dstTy, value);
+
+    bool ext = srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();
+
+    // float => float.
+    if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && ext)
+      return builder.create<arith::ExtFOp>(loc, dstTy, value);
+    if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && !ext)
+      return builder.create<arith::TruncFOp>(loc, dstTy, value);
+
+    // int => int. Unsigned integers zero-extend; signed *and signless* ones
+    // (e.g. i32/i64) sign-extend, matching arith.index_cast semantics.
+    if (srcTy.isUnsignedInteger() && dstTy.isa<IntegerType>() && ext)
+      return builder.create<arith::ExtUIOp>(loc, dstTy, value);
+    if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && ext)
+      return builder.create<arith::ExtSIOp>(loc, dstTy, value);
+    if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && !ext)
+      return builder.create<arith::TruncIOp>(loc, dstTy, value);
+
+    llvm_unreachable("unhandled type casting");
+  }
+
+  return value;
+}
+
 mlir::Attribute mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
   if (tp.isa<FloatType>())
     return builder.getFloatAttr(tp, 1.0);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
@@ -27,7 +27,9 @@
   const Level lvlRank = enc.getLvlRank();

   SmallVector result;
-  auto indexType = tp.getSizesType();
+  // TODO: how can we get the lowering type for index type in the later pipeline
+  // to be consistent? LLVM::StructureType does not allow index fields.
+ auto indexType = IntegerType::get(tp.getContext(), 64); auto dimSizes = LLVM::LLVMArrayType::get(ctx, indexType, lvlRank); auto memSizes = LLVM::LLVMArrayType::get(ctx, indexType, getNumDataFieldsFromEncoding(enc)); @@ -49,6 +51,21 @@ constexpr uint64_t kMemSizePosInSpecifier = 1; class SpecifierStructBuilder : public StructBuilder { +private: + Value extractField(OpBuilder &builder, Location loc, + ArrayRef indices) { + return genCast(builder, loc, + builder.create(loc, value, indices), + builder.getIndexType()); + } + + void insertField(OpBuilder &builder, Location loc, ArrayRef indices, + Value v) { + value = builder.create( + loc, value, genCast(builder, loc, v, builder.getIntegerType(64)), + indices); + } + public: explicit SpecifierStructBuilder(Value specifier) : StructBuilder(specifier) { assert(value); @@ -83,29 +100,30 @@ /// Builds IR inserting the pos-th size into the descriptor. Value SpecifierStructBuilder::dimSize(OpBuilder &builder, Location loc, unsigned dim) { - return builder.create( - loc, value, ArrayRef({kDimSizePosInSpecifier, dim})); + return extractField(builder, loc, + ArrayRef{kDimSizePosInSpecifier, dim}); } /// Builds IR inserting the pos-th size into the descriptor. void SpecifierStructBuilder::setDimSize(OpBuilder &builder, Location loc, unsigned dim, Value size) { - value = builder.create( - loc, value, size, ArrayRef({kDimSizePosInSpecifier, dim})); + + insertField(builder, loc, ArrayRef{kDimSizePosInSpecifier, dim}, + size); } /// Builds IR extracting the pos-th memory size into the descriptor. Value SpecifierStructBuilder::memSize(OpBuilder &builder, Location loc, unsigned pos) { - return builder.create( - loc, value, ArrayRef({kMemSizePosInSpecifier, pos})); + return extractField(builder, loc, + ArrayRef{kMemSizePosInSpecifier, pos}); } /// Builds IR inserting the pos-th memory size into the descriptor. 
void SpecifierStructBuilder::setMemSize(OpBuilder &builder, Location loc, unsigned pos, Value size) { - value = builder.create( - loc, value, size, ArrayRef({kMemSizePosInSpecifier, pos})); + insertField(builder, loc, ArrayRef{kMemSizePosInSpecifier, pos}, + size); } } // namespace diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -67,25 +67,18 @@ } } -/// Adds index conversions where needed. -static Value toType(OpBuilder &builder, Location loc, Value value, Type tp) { - if (value.getType() != tp) - return builder.create(loc, tp, value); - return value; -} - /// Generates a load with proper index typing. static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) { - idx = toType(builder, loc, idx, builder.getIndexType()); + idx = genCast(builder, loc, idx, builder.getIndexType()); return builder.create(loc, mem, idx); } /// Generates a store with proper index typing and (for indices) proper value. 
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem, Value idx) { - idx = toType(builder, loc, idx, builder.getIndexType()); - val = toType(builder, loc, val, - mem.getType().cast().getElementType()); + idx = genCast(builder, loc, idx, builder.getIndexType()); + val = genCast(builder, loc, val, + mem.getType().cast().getElementType()); builder.create(loc, val, mem, idx); } @@ -141,7 +134,7 @@ auto pushBackOp = builder.create( loc, desc.getSpecifierField(builder, loc, specFieldKind, lvl), field, - toType(builder, loc, value, etp), repeat); + genCast(builder, loc, value, etp), repeat); desc.setMemRefField(kind, lvl, pushBackOp.getOutBuffer()); desc.setSpecifierField(builder, loc, specFieldKind, lvl, @@ -338,7 +331,7 @@ msz = builder.create(loc, msz, idxStrideC); } Value phim1 = builder.create( - loc, toType(builder, loc, phi, indexType), one); + loc, genCast(builder, loc, phi, indexType), one); // Conditional expression. Value lt = builder.create(loc, arith::CmpIPredicate::ult, plo, phi); @@ -350,9 +343,9 @@ builder, loc, desc.getMemRefField(idxIndex), idxStride > 1 ? builder.create(loc, phim1, idxStrideC) : phim1); - Value eq = builder.create(loc, arith::CmpIPredicate::eq, - toType(builder, loc, crd, indexType), - indices[lvl]); + Value eq = builder.create( + loc, arith::CmpIPredicate::eq, genCast(builder, loc, crd, indexType), + indices[lvl]); builder.create(loc, eq); builder.setInsertionPointToStart(&ifOp1.getElseRegion().front()); if (lvl > 0) @@ -1226,8 +1219,8 @@ // Converts MemRefs back to Tensors. 
Value data = rewriter.create(loc, dataBuf); Value indices = rewriter.create(loc, idxBuf); - Value nnz = toType(rewriter, loc, desc.getValMemSize(rewriter, loc), - op.getNnz().getType()); + Value nnz = genCast(rewriter, loc, desc.getValMemSize(rewriter, loc), + op.getNnz().getType()); rewriter.replaceOp(op, {data, indices, nnz}); return success(); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h @@ -241,11 +241,6 @@ StorageSpecifierKind kind, std::optional dim); - // FIXME: see note [CLARIFY_DIM_LVL]. - Type getFieldType(StorageSpecifierKind kind, std::optional dim) { - return specifier.getType().getFieldType(kind, dim); - } - private: TypedValue specifier; }; @@ -283,6 +278,8 @@ /// Getters: get the value for required field. /// + Value getSpecifier() const { return fields.back(); } + // FIXME: see note [CLARIFY_DIM_LVL]. Value getSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp @@ -22,13 +22,6 @@ // Private helper methods. 
//===----------------------------------------------------------------------===// -static Value createIndexCast(OpBuilder &builder, Location loc, Value value, - Type to) { - if (value.getType() != to) - return builder.create(loc, to, value); - return value; -} - static IntegerAttr fromOptionalInt(MLIRContext *ctx, std::optional dim) { if (!dim) @@ -90,20 +83,17 @@ Value SparseTensorSpecifier::getSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional dim) { - return createIndexCast(builder, loc, - builder.create( - loc, getFieldType(kind, dim), specifier, kind, - fromOptionalInt(specifier.getContext(), dim)), - builder.getIndexType()); + return builder.create( + loc, specifier, kind, fromOptionalInt(specifier.getContext(), dim)); } void SparseTensorSpecifier::setSpecifierField(OpBuilder &builder, Location loc, Value v, StorageSpecifierKind kind, std::optional dim) { + assert(v.getType().isIndex()); specifier = builder.create( - loc, specifier, kind, fromOptionalInt(specifier.getContext(), dim), - createIndexCast(builder, loc, v, getFieldType(kind, dim))); + loc, specifier, kind, fromOptionalInt(specifier.getContext(), dim), v); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir --- a/mlir/test/Dialect/SparseTensor/codegen.mlir +++ b/mlir/test/Dialect/SparseTensor/codegen.mlir @@ -190,8 +190,7 @@ // CHECK-SAME: %[[A0:.*]]: memref, // CHECK-SAME: %[[A1:.*]]: !sparse_tensor.storage_specifier // CHECK: %[[A2:.*]] = sparse_tensor.storage_specifier.get %[[A1]] dim_sz at 2 -// CHECK: %[[A3:.*]] = arith.index_cast %[[A2]] : i64 to index -// CHECK: return %[[A3]] : index +// CHECK: return %[[A2]] : index func.func @sparse_dense_3d_dyn(%arg0: tensor) -> index { %c = arith.constant 1 : index %0 = tensor.dim %arg0, %c : tensor @@ -260,8 +259,7 @@ // CHECK-SAME: %[[A5:.*5]]: !sparse_tensor.storage_specifier // 
CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[S0:.*]] = sparse_tensor.storage_specifier.get %[[A5]] idx_mem_sz at 1 -// CHECK: %[[S1:.*]] = arith.index_cast %[[S0]] -// CHECK: %[[S2:.*]] = arith.divui %[[S1]], %[[C2]] : index +// CHECK: %[[S2:.*]] = arith.divui %[[S0]], %[[C2]] : index // CHECK: %[[R1:.*]] = memref.subview %[[A3]][0] {{\[}}%[[S2]]] [2] : memref to memref> // CHECK: %[[R2:.*]] = memref.cast %[[R1]] : memref> to memref> // CHECK: return %[[R2]] : memref> @@ -288,8 +286,7 @@ // CHECK-SAME: %[[A1:.*]]: memref, // CHECK-SAME: %[[A2:.*]]: memref, // CHECK-SAME: %[[A3:.*]]: !sparse_tensor.storage_specifier -// CHECK: %[[A4:.*]] = sparse_tensor.storage_specifier.get %[[A3]] val_mem_sz -// CHECK: %[[NOE:.*]] = arith.index_cast %[[A4]] : i64 to index +// CHECK: %[[NOE:.*]] = sparse_tensor.storage_specifier.get %[[A3]] val_mem_sz // CHECK: return %[[NOE]] : index func.func @sparse_noe(%arg0: tensor<128xf64, #SparseVector>) -> index { %0 = sparse_tensor.number_of_entries %arg0 : tensor<128xf64, #SparseVector> @@ -312,8 +309,8 @@ // CHECK-LABEL: func.func @sparse_alloc_csc( // CHECK-SAME: %[[A0:.*]]: index) -> (memref, memref, memref, !sparse_tensor.storage_specifier -// CHECK: %[[A1:.*]] = arith.constant 10 : i64 -// CHECK: %[[A2:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[A1:.*]] = arith.constant 10 : index +// CHECK-DAG: %[[A2:.*]] = arith.constant 0 : index // CHECK: %[[A3:.*]] = memref.alloc() : memref<16xindex> // CHECK: %[[A4:.*]] = memref.cast %[[A3]] : memref<16xindex> to memref // CHECK: %[[A5:.*]] = memref.alloc() : memref<16xindex> @@ -321,17 +318,13 @@ // CHECK: %[[A7:.*]] = memref.alloc() : memref<16xf64> // CHECK: %[[A8:.*]] = memref.cast %[[A7]] : memref<16xf64> to memref // CHECK: %[[A9:.*]] = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier -// CHECK: %[[A10:.*]] = arith.index_cast %[[A0]] : index to i64 -// CHECK: %[[A11:.*]] = sparse_tensor.storage_specifier.set %[[A9]] dim_sz at 0 with %[[A10]] : 
i64, !sparse_tensor.storage_specifier -// CHECK: %[[A12:.*]] = sparse_tensor.storage_specifier.set %[[A11]] dim_sz at 1 with %[[A1]] : i64, !sparse_tensor.storage_specifier -// CHECK: %[[A13:.*]] = sparse_tensor.storage_specifier.get %[[A12]] ptr_mem_sz at 1 : !sparse_tensor.storage_specifier -// CHECK: %[[A14:.*]] = arith.index_cast %[[A13]] : i64 to index -// CHECK: %[[A15:.*]], %[[A16:.*]] = sparse_tensor.push_back %[[A14]], %[[A4]], %[[A2]] : index, memref, index -// CHECK: %[[A17:.*]] = arith.index_cast %[[A16]] : index to i64 -// CHECK: %[[A18:.*]] = sparse_tensor.storage_specifier.set %[[A12]] ptr_mem_sz at 1 with %[[A17]] : i64, !sparse_tensor.storage_specifier -// CHECK: %[[A23:.*]], %[[A24:.*]] = sparse_tensor.push_back %[[A16]], %[[A15]], %[[A2]], %[[A0]] : index, memref, index, index -// CHECK: %[[A25:.*]] = arith.index_cast %[[A24]] : index to i64 -// CHECK: %[[A26:.*]] = sparse_tensor.storage_specifier.set %[[A18]] ptr_mem_sz at 1 with %[[A25]] : i64, !sparse_tensor.storage_specifier +// CHECK: %[[A11:.*]] = sparse_tensor.storage_specifier.set %[[A9]] dim_sz at 0 with %[[A0]] : !sparse_tensor.storage_specifier +// CHECK: %[[A12:.*]] = sparse_tensor.storage_specifier.set %[[A11]] dim_sz at 1 with %[[A1]] : !sparse_tensor.storage_specifier +// CHECK: %[[A14:.*]] = sparse_tensor.storage_specifier.get %[[A12]] ptr_mem_sz at 1 : !sparse_tensor.storage_specifier +// CHECK: %[[A15:.*]], %[[A17:.*]] = sparse_tensor.push_back %[[A14]], %[[A4]], %[[A2]] : index, memref, index +// CHECK: %[[A18:.*]] = sparse_tensor.storage_specifier.set %[[A12]] ptr_mem_sz at 1 with %[[A17]] : !sparse_tensor.storage_specifier +// CHECK: %[[A23:.*]], %[[A25:.*]] = sparse_tensor.push_back %[[A17]], %[[A15]], %[[A2]], %[[A0]] : index, memref, index, index +// CHECK: %[[A26:.*]] = sparse_tensor.storage_specifier.set %[[A18]] ptr_mem_sz at 1 with %[[A25]] : !sparse_tensor.storage_specifier // CHECK: return %[[A23]], %[[A6]], %[[A8]], %[[A26]] : memref, memref, memref, 
!sparse_tensor.storage_specifier func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> { %0 = bufferization.alloc_tensor(%arg0) : tensor<10x?xf64, #CSC> @@ -340,23 +333,21 @@ } // CHECK-LABEL: func.func @sparse_alloc_3d() -> (memref, !sparse_tensor.storage_specifier -// CHECK: %[[A0:.*]] = arith.constant 6000 : index -// CHECK: %[[A1:.*]] = arith.constant 20 : i64 -// CHECK: %[[A2:.*]] = arith.constant 10 : i64 -// CHECK: %[[A3:.*]] = arith.constant 30 : i64 -// CHECK: %[[A4:.*]] = arith.constant 0.000000e+00 : f64 +// CHECK-DAG: %[[A0:.*]] = arith.constant 6000 : index +// CHECK-DAG: %[[A1:.*]] = arith.constant 20 : index +// CHECK-DAG: %[[A2:.*]] = arith.constant 10 : index +// CHECK-DAG: %[[A3:.*]] = arith.constant 30 : index +// CHECK-DAG: %[[A4:.*]] = arith.constant 0.000000e+00 : f64 // CHECK: %[[A5:.*]] = memref.alloc() : memref<6000xf64> // CHECK: %[[A6:.*]] = memref.cast %[[A5]] : memref<6000xf64> to memref // CHECK: %[[A7:.*]] = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier -// CHECK: %[[A8:.*]] = sparse_tensor.storage_specifier.set %[[A7]] dim_sz at 0 with %[[A3]] : i64, !sparse_tensor.storage_specifier -// CHECK: %[[A9:.*]] = sparse_tensor.storage_specifier.set %[[A8]] dim_sz at 1 with %[[A2]] : i64, !sparse_tensor.storage_specifier -// CHECK: %[[A10:.*]] = sparse_tensor.storage_specifier.set %[[A9]] dim_sz at 2 with %[[A1]] : i64, !sparse_tensor.storage_specifier -// CHECK: %[[A11:.*]] = sparse_tensor.storage_specifier.get %[[A10]] val_mem_sz : !sparse_tensor.storage_specifier -// CHECK: %[[A12:.*]] = arith.index_cast %[[A11]] : i64 to index -// CHECK: %[[A13:.*]], %[[A14:.*]] = sparse_tensor.push_back %[[A12]], %[[A6]], %[[A4]], %[[A0]] : index, memref, f64, index -// CHECK: %[[A15:.*]] = arith.index_cast %[[A14]] : index to i64 -// CHECK: %[[A16:.*]] = sparse_tensor.storage_specifier.set %[[A10]] val_mem_sz with %[[A15]] : i64, !sparse_tensor.storage_specifier -// CHECK: return %[[A13]], %[[A16]] : memref, 
!sparse_tensor.storage_specifier +// CHECK: %[[A8:.*]] = sparse_tensor.storage_specifier.set %[[A7]] dim_sz at 0 with %[[A3]] : !sparse_tensor.storage_specifier +// CHECK: %[[A9:.*]] = sparse_tensor.storage_specifier.set %[[A8]] dim_sz at 1 with %[[A2]] : !sparse_tensor.storage_specifier +// CHECK: %[[A10:.*]] = sparse_tensor.storage_specifier.set %[[A9]] dim_sz at 2 with %[[A1]] : !sparse_tensor.storage_specifier +// CHECK: %[[A12:.*]] = sparse_tensor.storage_specifier.get %[[A10]] val_mem_sz : !sparse_tensor.storage_specifier +// CHECK: %[[A15:.*]], %[[A14:.*]] = sparse_tensor.push_back %[[A12]], %[[A6]], %[[A4]], %[[A0]] : index, memref, f64, index +// CHECK: %[[A16:.*]] = sparse_tensor.storage_specifier.set %[[A10]] val_mem_sz with %[[A14]] : !sparse_tensor.storage_specifier +// CHECK: return %[[A15]], %[[A16]] : memref, !sparse_tensor.storage_specifier func.func @sparse_alloc_3d() -> tensor<10x20x30xf64, #Dense3D> { %0 = bufferization.alloc_tensor() : tensor<10x20x30xf64, #Dense3D> %1 = sparse_tensor.load %0 : tensor<10x20x30xf64, #Dense3D> @@ -503,8 +494,7 @@ // CHECK: memref.dealloc %[[A4]] : memref // CHECK: memref.dealloc %[[A5]] : memref // CHECK: memref.dealloc %[[A6]] : memref -// CHECK: %[[A23:.*]] = sparse_tensor.storage_specifier.get %[[A24:.*]]#3 ptr_mem_sz at 1 : !sparse_tensor.storage_specifier -// CHECK: %[[A25:.*]] = arith.index_cast %[[A23]] : i64 to index +// CHECK: %[[A25:.*]] = sparse_tensor.storage_specifier.get %[[A24:.*]]#3 ptr_mem_sz at 1 : !sparse_tensor.storage_specifier // CHECK: %[[A26:.*]] = memref.load %[[A24]]#0{{\[}}%[[A13]]] : memref // CHECK: %[[A27:.*]] = scf.for %[[A28:.*]] = %[[A12]] to %[[A25]] step %[[A12]] iter_args(%[[A29:.*]] = %[[A26]]) -> (i32) { // CHECK: %[[A30:.*]] = memref.load %[[A24]]#0{{\[}}%[[A28]]] : memref @@ -562,8 +552,7 @@ // CHECK: memref.dealloc %[[A4]] : memref // CHECK: memref.dealloc %[[A5]] : memref // CHECK: memref.dealloc %[[A6]] : memref -// CHECK: %[[A22:.*]] = 
sparse_tensor.storage_specifier.get %[[A23:.*]]#3 ptr_mem_sz at 1 : !sparse_tensor.storage_specifier -// CHECK: %[[A24:.*]] = arith.index_cast %[[A22]] : i64 to index +// CHECK: %[[A24:.*]] = sparse_tensor.storage_specifier.get %[[A23:.*]]#3 ptr_mem_sz at 1 : !sparse_tensor.storage_specifier // CHECK: %[[A25:.*]] = memref.load %[[A23]]#0{{\[}}%[[A11]]] : memref // CHECK: %[[A26:.*]] = scf.for %[[A27:.*]] = %[[A12]] to %[[A24]] step %[[A12]] iter_args(%[[A28:.*]] = %[[A25]]) -> (index) { // CHECK: %[[A29:.*]] = memref.load %[[A23]]#0{{\[}}%[[A27]]] : memref diff --git a/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir b/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir --- a/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir +++ b/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir @@ -17,16 +17,12 @@ // CHECK: %[[VAL_9:.*]] = memref.cast %[[VAL_8]] : memref<16xf64> to memref // CHECK: linalg.fill ins(%[[VAL_2]] : f64) outs(%[[VAL_8]] : memref<16xf64>) // CHECK: %[[VAL_10:.*]] = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier -// CHECK: %[[VAL_11:.*]] = arith.index_cast %[[VAL_0]] : index to i64 -// CHECK: %[[VAL_12:.*]] = sparse_tensor.storage_specifier.set %[[VAL_10]] dim_sz at 0 with %[[VAL_11]] : i64, !sparse_tensor.storage_specifier -// CHECK: %[[VAL_13:.*]] = sparse_tensor.storage_specifier.get %[[VAL_12]] ptr_mem_sz at 0 : !sparse_tensor.storage_specifier -// CHECK: %[[VAL_14:.*]] = arith.index_cast %[[VAL_13]] : i64 to index -// CHECK: %[[VAL_15:.*]], %[[VAL_16:.*]] = sparse_tensor.push_back %[[VAL_14]], %[[VAL_5]], %[[VAL_3]] : index, memref, index -// CHECK: %[[VAL_17:.*]] = arith.index_cast %[[VAL_16]] : index to i64 -// CHECK: %[[VAL_18:.*]] = sparse_tensor.storage_specifier.set %[[VAL_12]] ptr_mem_sz at 0 with %[[VAL_17]] : i64, !sparse_tensor.storage_specifier -// CHECK: %[[VAL_19:.*]], %[[VAL_20:.*]] = sparse_tensor.push_back %[[VAL_16]], %[[VAL_15]], %[[VAL_3]], 
%[[VAL_1]] : index, memref, index, index -// CHECK: %[[VAL_21:.*]] = arith.index_cast %[[VAL_20]] : index to i64 -// CHECK: %[[VAL_22:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]] ptr_mem_sz at 0 with %[[VAL_21]] : i64, !sparse_tensor.storage_specifier +// CHECK: %[[VAL_12:.*]] = sparse_tensor.storage_specifier.set %[[VAL_10]] dim_sz at 0 with %[[VAL_0]] : !sparse_tensor.storage_specifier +// CHECK: %[[VAL_14:.*]] = sparse_tensor.storage_specifier.get %[[VAL_12]] ptr_mem_sz at 0 : !sparse_tensor.storage_specifier +// CHECK: %[[VAL_15:.*]], %[[VAL_17:.*]] = sparse_tensor.push_back %[[VAL_14]], %[[VAL_5]], %[[VAL_3]] : index, memref, index +// CHECK: %[[VAL_18:.*]] = sparse_tensor.storage_specifier.set %[[VAL_12]] ptr_mem_sz at 0 with %[[VAL_17]] : !sparse_tensor.storage_specifier +// CHECK: %[[VAL_19:.*]], %[[VAL_21:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_15]], %[[VAL_3]], %[[VAL_1]] : index, memref, index, index +// CHECK: %[[VAL_22:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]] ptr_mem_sz at 0 with %[[VAL_21]] : !sparse_tensor.storage_specifier // CHECK: return %[[VAL_19]], %[[VAL_7]], %[[VAL_9]], %[[VAL_22]] : memref, memref, memref, !sparse_tensor.storage_specifier func.func @sparse_alloc_sparse_vector(%arg0: index) -> tensor { %0 = bufferization.alloc_tensor(%arg0) : tensor diff --git a/mlir/test/Dialect/SparseTensor/fold.mlir b/mlir/test/Dialect/SparseTensor/fold.mlir --- a/mlir/test/Dialect/SparseTensor/fold.mlir +++ b/mlir/test/Dialect/SparseTensor/fold.mlir @@ -48,17 +48,17 @@ // CHECK-LABEL: func @sparse_get_specifier_dce_fold( // CHECK-SAME: %[[A0:.*]]: !sparse_tensor.storage_specifier -// CHECK-SAME: %[[A1:.*]]: i64, -// CHECK-SAME: %[[A2:.*]]: i64) +// CHECK-SAME: %[[A1:.*]]: index, +// CHECK-SAME: %[[A2:.*]]: index) // CHECK-NOT: sparse_tensor.storage_specifier.set // CHECK-NOT: sparse_tensor.storage_specifier.get // CHECK: return %[[A1]] -func.func @sparse_get_specifier_dce_fold(%arg0: 
!sparse_tensor.storage_specifier<#SparseVector>, %arg1: i64, %arg2: i64) -> i64 { +func.func @sparse_get_specifier_dce_fold(%arg0: !sparse_tensor.storage_specifier<#SparseVector>, %arg1: index, %arg2: index) -> index { %0 = sparse_tensor.storage_specifier.set %arg0 dim_sz at 0 with %arg1 - : i64, !sparse_tensor.storage_specifier<#SparseVector> + : !sparse_tensor.storage_specifier<#SparseVector> %1 = sparse_tensor.storage_specifier.set %0 ptr_mem_sz at 0 with %arg2 - : i64, !sparse_tensor.storage_specifier<#SparseVector> + : !sparse_tensor.storage_specifier<#SparseVector> %2 = sparse_tensor.storage_specifier.get %1 dim_sz at 0 - : !sparse_tensor.storage_specifier<#SparseVector> to i64 - return %2 : i64 + : !sparse_tensor.storage_specifier<#SparseVector> + return %2 : index } diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir --- a/mlir/test/Dialect/SparseTensor/invalid.mlir +++ b/mlir/test/Dialect/SparseTensor/invalid.mlir @@ -252,68 +252,44 @@ #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> -func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> i64 { +func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> index { // expected-error@+1 {{redundant level argument for querying value memory size}} %0 = sparse_tensor.storage_specifier.get %arg0 val_mem_sz at 0 - : !sparse_tensor.storage_specifier<#SparseVector> to i64 - return %0 : i64 + : !sparse_tensor.storage_specifier<#SparseVector> + return %0 : index } // ----- #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> -func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> i64 { +func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> index { // expected-error@+1 {{missing level argument}} %0 = sparse_tensor.storage_specifier.get %arg0 idx_mem_sz - : !sparse_tensor.storage_specifier<#SparseVector> to i64 - 
return %0 : i64 + : !sparse_tensor.storage_specifier<#SparseVector> + return %0 : index } // ----- #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> -func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> i64 { +func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> index { // expected-error@+1 {{requested level out of bound}} %0 = sparse_tensor.storage_specifier.get %arg0 dim_sz at 1 - : !sparse_tensor.storage_specifier<#SparseVector> to i64 - return %0 : i64 + : !sparse_tensor.storage_specifier<#SparseVector> + return %0 : index } // ----- #COO = #sparse_tensor.encoding<{dimLevelType = ["compressed-nu", "singleton"]}> -func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#COO>) -> i64 { +func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#COO>) -> index { // expected-error@+1 {{requested pointer memory size on a singleton level}} %0 = sparse_tensor.storage_specifier.get %arg0 ptr_mem_sz at 1 - : !sparse_tensor.storage_specifier<#COO> to i64 - return %0 : i64 -} - -// ----- - -#COO = #sparse_tensor.encoding<{dimLevelType = ["compressed-nu", "singleton"]}> - -func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#COO>) -> i64 { - // expected-error@+1 {{type mismatch between requested }} - %0 = sparse_tensor.storage_specifier.get %arg0 ptr_mem_sz at 0 - : !sparse_tensor.storage_specifier<#COO> to i32 - return %0 : i32 -} - -// ----- - -#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> - -func.func @sparse_set_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>, - %arg1: i32) - -> !sparse_tensor.storage_specifier<#SparseVector> { - // expected-error@+1 {{type mismatch between requested }} - %0 = sparse_tensor.storage_specifier.set %arg0 dim_sz at 0 with %arg1 - : i32, !sparse_tensor.storage_specifier<#SparseVector> - return %0 : !sparse_tensor.storage_specifier<#SparseVector> + : !sparse_tensor.storage_specifier<#COO> + 
return %0 : index } // ----- diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir --- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir +++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir @@ -184,11 +184,11 @@ // CHECK-LABEL: func @sparse_get_md( // CHECK-SAME: %[[A:.*]]: !sparse_tensor.storage_specifier<#{{.*}}> // CHECK: %[[T:.*]] = sparse_tensor.storage_specifier.get %[[A]] dim_sz at 0 -// CHECK: return %[[T]] : i64 -func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> i64 { +// CHECK: return %[[T]] : index +func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> index { %0 = sparse_tensor.storage_specifier.get %arg0 dim_sz at 0 - : !sparse_tensor.storage_specifier<#SparseVector> to i64 - return %0 : i64 + : !sparse_tensor.storage_specifier<#SparseVector> + return %0 : index } // ----- @@ -197,13 +197,13 @@ // CHECK-LABEL: func @sparse_set_md( // CHECK-SAME: %[[A:.*]]: !sparse_tensor.storage_specifier<#{{.*}}>, -// CHECK-SAME: %[[I:.*]]: i64) +// CHECK-SAME: %[[I:.*]]: index) // CHECK: %[[T:.*]] = sparse_tensor.storage_specifier.set %[[A]] dim_sz at 0 with %[[I]] // CHECK: return %[[T]] : !sparse_tensor.storage_specifier<#{{.*}}> -func.func @sparse_set_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>, %arg1: i64) +func.func @sparse_set_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>, %arg1: index) -> !sparse_tensor.storage_specifier<#SparseVector> { %0 = sparse_tensor.storage_specifier.set %arg0 dim_sz at 0 with %arg1 - : i64, !sparse_tensor.storage_specifier<#SparseVector> + : !sparse_tensor.storage_specifier<#SparseVector> return %0 : !sparse_tensor.storage_specifier<#SparseVector> } diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir +++ 
b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir @@ -25,8 +25,7 @@ // CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_4]], %[[VAL_8]] : index // CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_4]]] : memref // CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_9]]] : memref -// CHECK: %[[VAL_12:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]] idx_mem_sz at 1 : !sparse_tensor.storage_specifier -// CHECK: %[[VAL_13:.*]] = arith.index_cast %[[VAL_12]] : i64 to index +// CHECK: %[[VAL_13:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]] idx_mem_sz at 1 : !sparse_tensor.storage_specifier // CHECK: %[[VAL_14:.*]] = arith.subi %[[VAL_11]], %[[VAL_8]] : index // CHECK: %[[VAL_15:.*]] = arith.cmpi ult, %[[VAL_10]], %[[VAL_11]] : index // CHECK: %[[VAL_16:.*]] = scf.if %[[VAL_15]] -> (i1) { @@ -42,16 +41,13 @@ // CHECK: } else { // CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_13]], %[[VAL_8]] : index // CHECK: memref.store %[[VAL_21]], %[[VAL_0]]{{\[}}%[[VAL_9]]] : memref -// CHECK: %[[VAL_22:.*]], %[[VAL_23:.*]] = sparse_tensor.push_back %[[VAL_13]], %[[VAL_1]], %[[VAL_5]] : index, memref, index -// CHECK: %[[VAL_24:.*]] = arith.index_cast %[[VAL_23]] : index to i64 -// CHECK: %[[VAL_25:.*]] = sparse_tensor.storage_specifier.set %[[VAL_3]] idx_mem_sz at 1 with %[[VAL_24]] : i64, !sparse_tensor.storage_specifier +// CHECK: %[[VAL_22:.*]], %[[VAL_24:.*]] = sparse_tensor.push_back %[[VAL_13]], %[[VAL_1]], %[[VAL_5]] : index, memref, index +// CHECK: %[[VAL_25:.*]] = sparse_tensor.storage_specifier.set %[[VAL_3]] idx_mem_sz at 1 with %[[VAL_24]] : !sparse_tensor.storage_specifier // CHECK: scf.yield %[[VAL_22]], %[[VAL_25]] : memref, !sparse_tensor.storage_specifier // CHECK: } -// CHECK: %[[VAL_26:.*]] = sparse_tensor.storage_specifier.get %[[VAL_27:.*]]#1 val_mem_sz : !sparse_tensor.storage_specifier -// CHECK: %[[VAL_28:.*]] = arith.index_cast %[[VAL_26]] : i64 to index +// CHECK: %[[VAL_28:.*]] = sparse_tensor.storage_specifier.get 
%[[VAL_27:.*]]#1 val_mem_sz : !sparse_tensor.storage_specifier // CHECK: %[[VAL_29:.*]], %[[VAL_30:.*]] = sparse_tensor.push_back %[[VAL_28]], %[[VAL_2]], %[[VAL_6]] : index, memref, f64 -// CHECK: %[[VAL_31:.*]] = arith.index_cast %[[VAL_30]] : index to i64 -// CHECK: %[[VAL_32:.*]] = sparse_tensor.storage_specifier.set %[[VAL_27]]#1 val_mem_sz with %[[VAL_31]] : i64, !sparse_tensor.storage_specifier +// CHECK: %[[VAL_32:.*]] = sparse_tensor.storage_specifier.set %[[VAL_27]]#1 val_mem_sz with %[[VAL_30]] : !sparse_tensor.storage_specifier // CHECK: return %[[VAL_0]], %[[VAL_27]]#0, %[[VAL_29]], %[[VAL_32]] : memref, memref, memref, !sparse_tensor.storage_specifier // CHECK: } @@ -64,94 +60,89 @@ // CHECK-SAME: %[[VAL_5:.*5]]: memref, // CHECK-SAME: %[[VAL_6:.*6]]: memref, // CHECK-SAME: %[[VAL_7:.*7]]: !sparse_tensor.storage_specifier -// CHECK: %[[VAL_8:.*]] = arith.constant 4 : index -// CHECK: %[[VAL_9:.*]] = arith.constant 4 : i64 -// CHECK: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f64 -// CHECK: %[[VAL_11:.*]] = arith.constant 0 : index -// CHECK: %[[VAL_12:.*]] = arith.constant 1 : index -// CHECK: %[[VAL_13:.*]] = arith.constant false -// CHECK: %[[VAL_14:.*]] = arith.constant true -// CHECK: %[[VAL_15:.*]] = memref.alloc() : memref<16xindex> -// CHECK: %[[VAL_16:.*]] = memref.cast %[[VAL_15]] : memref<16xindex> to memref -// CHECK: %[[VAL_17:.*]] = memref.alloc() : memref<16xindex> -// CHECK: %[[VAL_18:.*]] = memref.cast %[[VAL_17]] : memref<16xindex> to memref -// CHECK: %[[VAL_19:.*]] = memref.alloc() : memref<16xf64> -// CHECK: %[[VAL_20:.*]] = memref.cast %[[VAL_19]] : memref<16xf64> to memref -// CHECK: %[[VAL_21:.*]] = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier -// CHECK: %[[VAL_22:.*]] = sparse_tensor.storage_specifier.set %[[VAL_21]] dim_sz at 0 with %[[VAL_9]] : i64, !sparse_tensor.storage_specifier -// CHECK: %[[VAL_23:.*]] = sparse_tensor.storage_specifier.set %[[VAL_22]] dim_sz at 1 with %[[VAL_9]] : i64, 
!sparse_tensor.storage_specifier -// CHECK: %[[VAL_24:.*]] = sparse_tensor.storage_specifier.get %[[VAL_23]] ptr_mem_sz at 1 : !sparse_tensor.storage_specifier -// CHECK: %[[VAL_25:.*]] = arith.index_cast %[[VAL_24]] : i64 to index -// CHECK: %[[VAL_26:.*]], %[[VAL_27:.*]] = sparse_tensor.push_back %[[VAL_25]], %[[VAL_16]], %[[VAL_11]] : index, memref, index -// CHECK: %[[VAL_28:.*]] = arith.index_cast %[[VAL_27]] : index to i64 -// CHECK: %[[VAL_29:.*]] = sparse_tensor.storage_specifier.set %[[VAL_23]] ptr_mem_sz at 1 with %[[VAL_28]] : i64, !sparse_tensor.storage_specifier -// CHECK: %[[VAL_32:.*]], %[[VAL_33:.*]] = sparse_tensor.push_back %[[VAL_27]], %[[VAL_26]], %[[VAL_11]], %[[VAL_8]] : index, memref, index, index -// CHECK: %[[VAL_34:.*]] = arith.index_cast %[[VAL_33]] : index to i64 -// CHECK: %[[VAL_35:.*]] = sparse_tensor.storage_specifier.set %[[VAL_29]] ptr_mem_sz at 1 with %[[VAL_34]] : i64, !sparse_tensor.storage_specifier -// CHECK: %[[VAL_36:.*]] = memref.alloc() : memref<4xf64> -// CHECK: %[[VAL_37:.*]] = memref.alloc() : memref<4xi1> -// CHECK: %[[VAL_38:.*]] = memref.alloc() : memref<4xindex> -// CHECK: %[[VAL_39:.*]] = memref.cast %[[VAL_38]] : memref<4xindex> to memref -// CHECK: linalg.fill ins(%[[VAL_10]] : f64) outs(%[[VAL_36]] : memref<4xf64>) -// CHECK: linalg.fill ins(%[[VAL_13]] : i1) outs(%[[VAL_37]] : memref<4xi1>) -// CHECK: %[[VAL_40:.*]]:4 = scf.for %[[VAL_41:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_12]] iter_args(%[[VAL_42:.*]] = %[[VAL_32]], %[[VAL_43:.*]] = %[[VAL_18]], %[[VAL_44:.*]] = %[[VAL_20]], %[[VAL_45:.*]] = %[[VAL_35]]) -> (memref, memref, memref, !sparse_tensor.storage_specifier -// CHECK: %[[VAL_46:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_41]]] : memref -// CHECK: %[[VAL_47:.*]] = arith.addi %[[VAL_41]], %[[VAL_12]] : index -// CHECK: %[[VAL_48:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_47]]] : memref -// CHECK: %[[VAL_49:.*]] = scf.for %[[VAL_50:.*]] = %[[VAL_46]] to %[[VAL_48]] step %[[VAL_12]] 
iter_args(%[[VAL_51:.*]] = %[[VAL_11]]) -> (index) { -// CHECK: %[[VAL_52:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_50]]] : memref -// CHECK: %[[VAL_53:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_50]]] : memref -// CHECK: %[[VAL_54:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_52]]] : memref -// CHECK: %[[VAL_55:.*]] = arith.addi %[[VAL_52]], %[[VAL_12]] : index -// CHECK: %[[VAL_56:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_55]]] : memref -// CHECK: %[[VAL_57:.*]] = scf.for %[[VAL_58:.*]] = %[[VAL_54]] to %[[VAL_56]] step %[[VAL_12]] iter_args(%[[VAL_59:.*]] = %[[VAL_51]]) -> (index) { -// CHECK: %[[VAL_60:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_58]]] : memref -// CHECK: %[[VAL_61:.*]] = memref.load %[[VAL_36]]{{\[}}%[[VAL_60]]] : memref<4xf64> -// CHECK: %[[VAL_62:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_58]]] : memref -// CHECK: %[[VAL_63:.*]] = arith.mulf %[[VAL_53]], %[[VAL_62]] : f64 -// CHECK: %[[VAL_64:.*]] = arith.addf %[[VAL_61]], %[[VAL_63]] : f64 -// CHECK: %[[VAL_65:.*]] = memref.load %[[VAL_37]]{{\[}}%[[VAL_60]]] : memref<4xi1> -// CHECK: %[[VAL_66:.*]] = arith.cmpi eq, %[[VAL_65]], %[[VAL_13]] : i1 -// CHECK: %[[VAL_67:.*]] = scf.if %[[VAL_66]] -> (index) { -// CHECK: memref.store %[[VAL_14]], %[[VAL_37]]{{\[}}%[[VAL_60]]] : memref<4xi1> -// CHECK: memref.store %[[VAL_60]], %[[VAL_38]]{{\[}}%[[VAL_59]]] : memref<4xindex> -// CHECK: %[[VAL_68:.*]] = arith.addi %[[VAL_59]], %[[VAL_12]] : index -// CHECK: scf.yield %[[VAL_68]] : index +// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 4 : index +// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f64 +// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[VAL_12:.*]] = arith.constant false +// CHECK-DAG: %[[VAL_13:.*]] = arith.constant true +// CHECK: %[[VAL_14:.*]] = memref.alloc() : memref<16xindex> +// CHECK: %[[VAL_15:.*]] = memref.cast %[[VAL_14]] : memref<16xindex> to memref +// CHECK: %[[VAL_16:.*]] = memref.alloc() : 
memref<16xindex> +// CHECK: %[[VAL_17:.*]] = memref.cast %[[VAL_16]] : memref<16xindex> to memref +// CHECK: %[[VAL_18:.*]] = memref.alloc() : memref<16xf64> +// CHECK: %[[VAL_19:.*]] = memref.cast %[[VAL_18]] : memref<16xf64> to memref +// CHECK: %[[VAL_20:.*]] = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier +// CHECK: %[[VAL_21:.*]] = sparse_tensor.storage_specifier.set %[[VAL_20]] dim_sz at 0 with %[[VAL_8]] : !sparse_tensor.storage_specifier +// CHECK: %[[VAL_22:.*]] = sparse_tensor.storage_specifier.set %[[VAL_21]] dim_sz at 1 with %[[VAL_8]] : !sparse_tensor.storage_specifier +// CHECK: %[[VAL_23:.*]] = sparse_tensor.storage_specifier.get %[[VAL_22]] ptr_mem_sz at 1 : !sparse_tensor.storage_specifier +// CHECK: %[[VAL_24:.*]], %[[VAL_25:.*]] = sparse_tensor.push_back %[[VAL_23]], %[[VAL_15]], %[[VAL_10]] : index, memref, index +// CHECK: %[[VAL_26:.*]] = sparse_tensor.storage_specifier.set %[[VAL_22]] ptr_mem_sz at 1 with %[[VAL_25]] : !sparse_tensor.storage_specifier +// CHECK: %[[VAL_27:.*]], %[[VAL_28:.*]] = sparse_tensor.push_back %[[VAL_25]], %[[VAL_24]], %[[VAL_10]], %[[VAL_8]] : index, memref, index, index +// CHECK: %[[VAL_29:.*]] = sparse_tensor.storage_specifier.set %[[VAL_26]] ptr_mem_sz at 1 with %[[VAL_28]] : !sparse_tensor.storage_specifier +// CHECK: %[[VAL_30:.*]] = memref.alloc() : memref<4xf64> +// CHECK: %[[VAL_31:.*]] = m +// CHECK: %[[VAL_32:.*]] = memref.alloc() : memref<4xindex> +// CHECK: %[[VAL_33:.*]] = memref.cast %[[VAL_32]] : memref<4xindex> to memref +// CHECK: linalg.fill ins(%[[VAL_9]] : f64) outs(%[[VAL_30]] : memref<4xf64>) +// CHECK: linalg.fill ins(%[[VAL_12]] : i1) outs(%[[VAL_31]] : memref<4xi1>) +// CHECK: %[[VAL_34:.*]]:4 = scf.for %[[VAL_35:.*]] = %[[VAL_10]] to %[[VAL_8]] step %[[VAL_11]] iter_args(%[[VAL_36:.*]] = %[[VAL_27]], %[[VAL_37:.*]] = %[[VAL_17]], %[[VAL_38:.*]] = %[[VAL_19]], %[[VAL_39:.*]] = %[[VAL_29]]) -> (memref, memref, memref, !sparse_tensor.storage_specifier +// CHECK: 
%[[VAL_40:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_35]]] : memref +// CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_35]], %[[VAL_11]] : index +// CHECK: %[[VAL_42:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_41]]] : memref +// CHECK: %[[VAL_43:.*]] = scf.for %[[VAL_44:.*]] = %[[VAL_40]] to %[[VAL_42]] step %[[VAL_11]] iter_args(%[[VAL_45:.*]] = %[[VAL_10]]) -> (index) { +// CHECK: %[[VAL_46:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_44]]] : memref +// CHECK: %[[VAL_47:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_44]]] : memref +// CHECK: %[[VAL_48:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_46]]] : memref +// CHECK: %[[VAL_49:.*]] = arith.addi %[[VAL_46]], %[[VAL_11]] : index +// CHECK: %[[VAL_50:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_49]]] : memref +// CHECK: %[[VAL_51:.*]] = scf.for %[[VAL_52:.*]] = %[[VAL_48]] to %[[VAL_50]] step %[[VAL_11]] iter_args(%[[VAL_53:.*]] = %[[VAL_45]]) -> (index) { +// CHECK: %[[VAL_54:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_52]]] : memref +// CHECK: %[[VAL_55:.*]] = memref.load %[[VAL_30]]{{\[}}%[[VAL_54]]] : memref<4xf64> +// CHECK: %[[VAL_56:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_52]]] : memref +// CHECK: %[[VAL_57:.*]] = arith.mulf %[[VAL_47]], %[[VAL_56]] : f64 +// CHECK: %[[VAL_58:.*]] = arith.addf %[[VAL_55]], %[[VAL_57]] : f64 +// CHECK: %[[VAL_59:.*]] = memref.load %[[VAL_31]]{{\[}}%[[VAL_54]]] : memref<4xi1> +// CHECK: %[[VAL_60:.*]] = arith.cmpi eq, %[[VAL_59]], %[[VAL_12]] : i1 +// CHECK: %[[VAL_61:.*]] = scf.if %[[VAL_60]] -> (index) { +// CHECK: memref.store %[[VAL_13]], %[[VAL_31]]{{\[}}%[[VAL_54]]] : memref<4xi1> +// CHECK: memref.store %[[VAL_54]], %[[VAL_32]]{{\[}}%[[VAL_53]]] : memref<4xindex> +// CHECK: %[[VAL_62:.*]] = arith.addi %[[VAL_53]], %[[VAL_11]] : index +// CHECK: scf.yield %[[VAL_62]] : index // CHECK: } else { -// CHECK: scf.yield %[[VAL_59]] : index +// CHECK: scf.yield %[[VAL_53]] : index // CHECK: } -// CHECK: memref.store %[[VAL_64]], %[[VAL_36]]{{\[}}%[[VAL_60]]] : memref<4xf64> -// CHECK: scf.yield 
%[[VAL_69:.*]] : index +// CHECK: memref.store %[[VAL_58]], %[[VAL_30]]{{\[}}%[[VAL_54]]] : memref<4xf64> +// CHECK: scf.yield %[[VAL_63:.*]] : index // CHECK: } {"Emitted from" = "linalg.generic"} -// CHECK: scf.yield %[[VAL_70:.*]] : index +// CHECK: scf.yield %[[VAL_64:.*]] : index // CHECK: } {"Emitted from" = "linalg.generic"} -// CHECK: sparse_tensor.sort hybrid_quick_sort %[[VAL_71:.*]], %[[VAL_39]] : memref -// CHECK: %[[VAL_72:.*]]:4 = scf.for %[[VAL_73:.*]] = %[[VAL_11]] to %[[VAL_71]] step %[[VAL_12]] iter_args(%[[VAL_74:.*]] = %[[VAL_42]], %[[VAL_75:.*]] = %[[VAL_43]], %[[VAL_76:.*]] = %[[VAL_44]], %[[VAL_77:.*]] = %[[VAL_45]]) -> (memref, memref, memref, !sparse_tensor.storage_specifier -// CHECK: %[[VAL_78:.*]] = memref.load %[[VAL_38]]{{\[}}%[[VAL_73]]] : memref<4xindex> -// CHECK: %[[VAL_79:.*]] = memref.load %[[VAL_36]]{{\[}}%[[VAL_78]]] : memref<4xf64> -// CHECK: %[[VAL_80:.*]]:4 = func.call @_insert_dense_compressed_4_4_f64_0_0(%[[VAL_74]], %[[VAL_75]], %[[VAL_76]], %[[VAL_77]], %[[VAL_41]], %[[VAL_78]], %[[VAL_79]]) : (memref, memref, memref, !sparse_tensor.storage_specifier -// CHECK: memref.store %[[VAL_10]], %[[VAL_36]]{{\[}}%[[VAL_78]]] : memref<4xf64> -// CHECK: memref.store %[[VAL_13]], %[[VAL_37]]{{\[}}%[[VAL_78]]] : memref<4xi1> -// CHECK: scf.yield %[[VAL_80]]#0, %[[VAL_80]]#1, %[[VAL_80]]#2, %[[VAL_80]]#3 : memref, memref, memref, !sparse_tensor.storage_specifier +// CHECK: sparse_tensor.sort hybrid_quick_sort %[[VAL_65:.*]], %[[VAL_33]] : memref +// CHECK: %[[VAL_66:.*]]:4 = scf.for %[[VAL_67:.*]] = %[[VAL_10]] to %[[VAL_65]] step %[[VAL_11]] iter_args(%[[VAL_68:.*]] = %[[VAL_36]], %[[VAL_69:.*]] = %[[VAL_37]], %[[VAL_70:.*]] = %[[VAL_38]], %[[VAL_71:.*]] = %[[VAL_39]]) -> (memref, memref, memref, !sparse_tensor.storage_specifier +// CHECK: %[[VAL_72:.*]] = memref.load %[[VAL_32]]{{\[}}%[[VAL_67]]] : memref<4xindex> +// CHECK: %[[VAL_73:.*]] = memref.load %[[VAL_30]]{{\[}}%[[VAL_72]]] : memref<4xf64> +// CHECK: %[[VAL_74:.*]]:4 = 
func.call @_insert_dense_compressed_4_4_f64_0_0(%[[VAL_68]], %[[VAL_69]], %[[VAL_70]], %[[VAL_71]], %[[VAL_35]], %[[VAL_72]], %[[VAL_73]]) : (memref, memref, memref, !sparse_tensor.storage_specifie +// CHECK: memref.store %[[VAL_9]], %[[VAL_30]]{{\[}}%[[VAL_72]]] : memref<4xf64> +// CHECK: memref.store %[[VAL_12]], %[[VAL_31]]{{\[}}%[[VAL_72]]] : memref<4xi1> +// CHECK: scf.yield %[[VAL_74]]#0, %[[VAL_74]]#1, %[[VAL_74]]#2, %[[VAL_74]]#3 : memref, memref, memref, !sparse_tensor.storage_specifier // CHECK: } -// CHECK: scf.yield %[[VAL_81:.*]]#0, %[[VAL_81]]#1, %[[VAL_81]]#2, %[[VAL_81]]#3 : memref, memref, memref, !sparse_tensor.storage_specifier +// CHECK: scf.yield %[[VAL_75:.*]]#0, %[[VAL_75]]#1, %[[VAL_75]]#2, %[[VAL_75]]#3 : memref, memref, memref, !sparse_tensor.storage_specifier // CHECK: } {"Emitted from" = "linalg.generic"} -// CHECK: memref.dealloc %[[VAL_36]] : memref<4xf64> -// CHECK: memref.dealloc %[[VAL_37]] : memref<4xi1> -// CHECK: memref.dealloc %[[VAL_38]] : memref<4xindex> -// CHECK: %[[VAL_82:.*]] = sparse_tensor.storage_specifier.get %[[VAL_83:.*]]#3 ptr_mem_sz at 1 : !sparse_tensor.storage_specifier -// CHECK: %[[VAL_84:.*]] = arith.index_cast %[[VAL_82]] : i64 to index -// CHECK: %[[VAL_85:.*]] = memref.load %[[VAL_83]]#0{{\[}}%[[VAL_11]]] : memref -// CHECK: %[[VAL_86:.*]] = scf.for %[[VAL_87:.*]] = %[[VAL_12]] to %[[VAL_84]] step %[[VAL_12]] iter_args(%[[VAL_88:.*]] = %[[VAL_85]]) -> (index) { -// CHECK: %[[VAL_89:.*]] = memref.load %[[VAL_83]]#0{{\[}}%[[VAL_87]]] : memref -// CHECK: %[[VAL_90:.*]] = arith.cmpi eq, %[[VAL_89]], %[[VAL_11]] : index -// CHECK: %[[VAL_91:.*]] = arith.select %[[VAL_90]], %[[VAL_88]], %[[VAL_89]] : index -// CHECK: scf.if %[[VAL_90]] { -// CHECK: memref.store %[[VAL_88]], %[[VAL_83]]#0{{\[}}%[[VAL_87]]] : memref +// CHECK: memref.dealloc %[[VAL_30]] : memref<4xf64> +// CHECK: memref.dealloc %[[VAL_31]] : memref<4xi1> +// CHECK: memref.dealloc %[[VAL_32]] : memref<4xindex> +// CHECK: %[[VAL_76:.*]] = 
sparse_tensor.storage_specifier.get %[[VAL_77:.*]]#3 ptr_mem_sz at 1 : !sparse_tensor.storage_specifier +// CHECK: %[[VAL_78:.*]] = memref.load %[[VAL_77]]#0{{\[}}%[[VAL_10]]] : memref +// CHECK: %[[VAL_79:.*]] = scf.for %[[VAL_80:.*]] = %[[VAL_11]] to %[[VAL_76]] step %[[VAL_11]] iter_args(%[[VAL_81:.*]] = %[[VAL_78]]) -> (index) { +// CHECK: %[[VAL_82:.*]] = memref.load %[[VAL_77]]#0{{\[}}%[[VAL_80]]] : memref +// CHECK: %[[VAL_83:.*]] = arith.cmpi eq, %[[VAL_82]], %[[VAL_10]] : index +// CHECK: %[[VAL_84:.*]] = arith.select %[[VAL_83]], %[[VAL_81]], %[[VAL_82]] : index +// CHECK: scf.if %[[VAL_83]] { +// CHECK: memref.store %[[VAL_81]], %[[VAL_77]]#0{{\[}}%[[VAL_80]]] : memref // CHECK: } -// CHECK: scf.yield %[[VAL_91]] : index +// CHECK: scf.yield %[[VAL_84]] : index // CHECK: } -// CHECK: return %[[VAL_83]]#0, %[[VAL_83]]#1, %[[VAL_83]]#2, %[[VAL_83]]#3 : memref, memref, memref, !sparse_tensor.storage_specifier +// CHECK: return %[[VAL_77]]#0, %[[VAL_77]]#1, %[[VAL_77]]#2, %[[VAL_77]]#3 : memref, memref, memref, !sparse_tensor.storage_specifier func.func @matmul(%A: tensor<4x8xf64, #CSR>, %B: tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR> { %C = bufferization.alloc_tensor() : tensor<4x4xf64, #CSR> diff --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir @@ -19,16 +19,13 @@ // CHECK: %[[VAL_10:.*]] = sparse_tensor.storage_specifier.init : // CHECK: %[[VAL_11:.*]] = arith.constant 6 : index // CHECK: %[[VAL_12:.*]] = arith.constant 100 : index -// CHECK: %[[VAL_13:.*]] = arith.index_cast %[[VAL_12]] : index to i32 -// CHECK: %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_10]] dim_sz at 0 with %[[VAL_13]] : i32, +// CHECK: %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_10]] dim_sz at 0 with %[[VAL_12]] // CHECK: %[[VAL_15:.*]] = arith.constant 2 : index -// CHECK: 
%[[VAL_16:.*]] = arith.index_cast %[[VAL_15]] : index to i32 -// CHECK: %[[VAL_17:.*]] = sparse_tensor.storage_specifier.set %[[VAL_14]] ptr_mem_sz at 0 with %[[VAL_16]] : i32, -// CHECK: %[[VAL_18:.*]] = arith.index_cast %[[VAL_11]] : index to i32 -// CHECK: %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_17]] idx_mem_sz at 0 with %[[VAL_18]] : i32, -// CHECK: %[[VAL_20:.*]] = sparse_tensor.storage_specifier.set %[[VAL_19]] dim_sz at 1 with %[[VAL_13]] : i32, -// CHECK: %[[VAL_21:.*]] = sparse_tensor.storage_specifier.set %[[VAL_20]] idx_mem_sz at 1 with %[[VAL_18]] : i32, -// CHECK: %[[VAL_22:.*]] = sparse_tensor.storage_specifier.set %[[VAL_21]] val_mem_sz with %[[VAL_18]] : i32, +// CHECK: %[[VAL_17:.*]] = sparse_tensor.storage_specifier.set %[[VAL_14]] ptr_mem_sz at 0 with %[[VAL_15]] +// CHECK: %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_17]] idx_mem_sz at 0 with %[[VAL_11]] +// CHECK: %[[VAL_20:.*]] = sparse_tensor.storage_specifier.set %[[VAL_19]] dim_sz at 1 with %[[VAL_12]] +// CHECK: %[[VAL_21:.*]] = sparse_tensor.storage_specifier.set %[[VAL_20]] idx_mem_sz at 1 with %[[VAL_11]] +// CHECK: %[[VAL_22:.*]] = sparse_tensor.storage_specifier.set %[[VAL_21]] val_mem_sz with %[[VAL_11]] // CHECK: return %[[VAL_4]], %[[VAL_7]], %[[VAL_9]], %[[VAL_22]] : memref, memref, memref, // CHECK: } func.func @sparse_pack(%data: tensor<6xf64>, %index: tensor<6x2xi32>) @@ -68,8 +65,7 @@ // CHECK: %[[VAL_19:.*]] = bufferization.to_tensor %[[VAL_20:.*]] : memref<6xf64> // CHECK: %[[VAL_21:.*]] = bufferization.to_tensor %[[VAL_17]] : memref<6x2xi32> // CHECK: %[[VAL_22:.*]] = sparse_tensor.storage_specifier -// CHECK: %[[VAL_23:.*]] = arith.index_cast %[[VAL_22]] : i32 to index -// CHECK: return %[[VAL_19]], %[[VAL_21]], %[[VAL_23]] : tensor<6xf64>, tensor<6x2xi32>, index +// CHECK: return %[[VAL_19]], %[[VAL_21]], %[[VAL_22]] : tensor<6xf64>, tensor<6x2xi32>, index // CHECK: } func.func @sparse_unpack(%sp: tensor<100x100xf64, #COO>) -> 
(tensor<6xf64>, tensor<6x2xi32>, index) { %d, %i, %nnz = sparse_tensor.unpack %sp : tensor<100x100xf64, #COO> diff --git a/mlir/test/Dialect/SparseTensor/specifier_to_llvm.mlir b/mlir/test/Dialect/SparseTensor/specifier_to_llvm.mlir --- a/mlir/test/Dialect/SparseTensor/specifier_to_llvm.mlir +++ b/mlir/test/Dialect/SparseTensor/specifier_to_llvm.mlir @@ -16,23 +16,25 @@ } // CHECK-LABEL: func.func @sparse_get_md( -// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>) -> i64 { +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>) -> index { // CHECK: %[[VAL_1:.*]] = llvm.extractvalue %[[VAL_0]][0, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)> -// CHECK: return %[[VAL_1]] : i64 -func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#CSR>) -> i64 { +// CHECK: %[[CAST:.*]] = arith.index_cast %[[VAL_1]] : i64 to index +// CHECK: return %[[CAST]] : index +func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#CSR>) -> index { %0 = sparse_tensor.storage_specifier.get %arg0 dim_sz at 0 - : !sparse_tensor.storage_specifier<#CSR> to i64 - return %0 : i64 + : !sparse_tensor.storage_specifier<#CSR> + return %0 : index } // CHECK-LABEL: func.func @sparse_set_md( // CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>, -// CHECK-SAME: %[[VAL_1:.*]]: i64) -> !llvm.struct<(array<2 x i64>, array<3 x i64>)> { -// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)> +// CHECK-SAME: %[[VAL_1:.*]]: index) -> !llvm.struct<(array<2 x i64>, array<3 x i64>)> { +// CHECK: %[[CAST:.*]] = arith.index_cast %[[VAL_1]] : index to i64 +// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[CAST]], %[[VAL_0]][0, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)> // CHECK: return %[[VAL_2]] : !llvm.struct<(array<2 x i64>, array<3 x i64>)> -func.func @sparse_set_md(%arg0: !sparse_tensor.storage_specifier<#CSR>, %arg1: i64) +func.func 
@sparse_set_md(%arg0: !sparse_tensor.storage_specifier<#CSR>, %arg1: index) -> !sparse_tensor.storage_specifier<#CSR> { %0 = sparse_tensor.storage_specifier.set %arg0 dim_sz at 0 with %arg1 - : i64, !sparse_tensor.storage_specifier<#CSR> + : !sparse_tensor.storage_specifier<#CSR> return %0 : !sparse_tensor.storage_specifier<#CSR> }