diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -45,6 +45,9 @@
 /// Returns true iff the given type is a type for a COO tensor with the last
 /// dimension level type being unique.
 bool isUniqueCOOType(RankedTensorType tp);
+/// Returns true iff the given sparse tensor encoding attribute has a trailing
+/// COO region starting at the given dimension.
+bool isUniqueCOOType(SparseTensorEncodingAttr enc, uint64_t s);
 
 //
 // Dimension level types.
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -266,20 +266,25 @@
 bool mlir::sparse_tensor::isUniqueCOOType(RankedTensorType tp) {
   SparseTensorEncodingAttr enc = getSparseTensorEncoding(tp);
-
   if (!enc)
     return false;
-  if (!isCompressedDim(tp, 0))
+  return isUniqueCOOType(enc, 0);
+}
+
+bool mlir::sparse_tensor::isUniqueCOOType(SparseTensorEncodingAttr enc,
+                                          uint64_t s) {
+  if (!isCompressedDim(enc, s))
     return false;
-  for (uint64_t i = 1, e = tp.getRank(); i < e; ++i)
-    if (!isSingletonDim(tp, i))
+  uint64_t rank = enc.getDimLevelType().size();
+  for (uint64_t i = s + 1; i < rank; ++i)
+    if (!isSingletonDim(enc, i))
       return false;
   // This works for rank == 1 (unique the only compressed) and rank > 1 (unique
   // on the last singleton).
-  return isUniqueDim(tp, tp.getRank() - 1);
+  return isUniqueDLT(getDimLevelType(enc, rank - 1));
 }
 
 uint64_t mlir::sparse_tensor::toOrigDim(SparseTensorEncodingAttr enc,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -183,6 +183,16 @@
 void sizesFromSrc(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
                   Location loc, Value src);
 
+/// Generates a 1D MemRefType with a dynamic size. When withLayout is set, adds
+/// a layout with unknown strides and offset to the type.
+inline MemRefType get1DMemRefType(Type etp, bool withLayout) {
+  auto layout = withLayout ? StridedLayoutAttr::get(etp.getContext(),
+                                                    ShapedType::kDynamic,
+                                                    {ShapedType::kDynamic})
+                           : StridedLayoutAttr();
+  return MemRefType::get(ShapedType::kDynamic, etp, layout);
+}
+
 /// Scans to top of generated loop.
 Operation *getTop(Operation *op);
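For context, here is a minimal sketch (not part of the patch) of the two memref types the new helper is expected to produce for an index element type. It assumes get1DMemRefType lives in namespace mlir::sparse_tensor, as suggested by CodegenUtils.h above, and that the header is reachable from the demo.

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
// Assumed local header from this patch providing get1DMemRefType.
#include "CodegenUtils.h"

// Illustrative only: builds the two flavors of 1D memref types side by side.
static void demoGet1DMemRefType() {
  mlir::MLIRContext ctx;
  mlir::Type idx = mlir::IndexType::get(&ctx);
  // Without a layout this is a plain memref<?xindex>.
  mlir::MemRefType flat =
      mlir::sparse_tensor::get1DMemRefType(idx, /*withLayout=*/false);
  // With a layout this is memref<?xindex, strided<[?], offset: ?>>.
  mlir::MemRefType view =
      mlir::sparse_tensor::get1DMemRefType(idx, /*withLayout=*/true);
  (void)flat;
  (void)view;
}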
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -8,6 +8,7 @@
 
 #include "LoopEmitter.h"
 #include "CodegenUtils.h"
+#include "SparseTensorStorageLayout.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
@@ -124,25 +125,25 @@
   auto rank = rtp.getRank();
   auto shape = rtp.getShape();
   auto enc = getSparseTensorEncoding(rtp);
-  auto dynShape = {ShapedType::kDynamic};
+  uint64_t cooStart = enc ? getCooStart(enc) : rank;
   // Scan all dimensions of current tensor.
   for (int64_t d = 0; d < rank; d++) {
     // This should be called only once at beginning.
    assert(!ptrBuffer[t][d] && !idxBuffer[t][d] && !highs[t][d]);
     // Handle sparse storage schemes.
     if (isCompressedDLT(dimTypes[t][d])) {
-      auto ptrTp =
-          MemRefType::get(dynShape, getPointerOverheadType(builder, enc));
-      auto indTp =
-          MemRefType::get(dynShape, getIndexOverheadType(builder, enc));
+      auto ptrTp = get1DMemRefType(getPointerOverheadType(builder, enc),
+                                   /*withLayout=*/false);
+      auto indTp = get1DMemRefType(getIndexOverheadType(builder, enc),
+                                   /*withLayout=*/d >= cooStart);
       auto dim = builder.getIndexAttr(d);
       // Generate sparse primitives to obtain pointer and indices.
      ptrBuffer[t][d] = builder.create<ToPointersOp>(loc, ptrTp, tensor, dim);
      idxBuffer[t][d] = builder.create<ToIndicesOp>(loc, indTp, tensor, dim);
     } else if (isSingletonDLT(dimTypes[t][d])) {
       // Singleton dimension, fetch indices.
-      auto indTp =
-          MemRefType::get(dynShape, getIndexOverheadType(builder, enc));
+      auto indTp = get1DMemRefType(getIndexOverheadType(builder, enc),
+                                   /*withLayout=*/d >= cooStart);
       auto dim = builder.getIndexAttr(d);
       idxBuffer[t][d] = builder.create<ToIndicesOp>(loc, indTp, tensor, dim);
     } else {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -911,15 +911,29 @@
 };
 
 /// Sparse codegen rule for index accesses.
-class SparseToIndicesConverter
-    : public SparseGetterOpConverter<SparseToIndicesConverter, ToIndicesOp> {
+class SparseToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
 public:
-  using SparseGetterOpConverter::SparseGetterOpConverter;
-  // Callback for SparseGetterOpConverter.
-  static Value getFieldForOp(const SparseTensorDescriptor &desc,
-                             ToIndicesOp op) {
+  using OpAdaptor = typename ToIndicesOp::Adaptor;
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Replace the requested indices access with the corresponding field.
+    // The cast_op is inserted by the type converter to intermix 1:N type
+    // conversion.
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
     uint64_t dim = op.getDimension().getZExtValue();
-    return desc.getIdxMemRef(dim);
+    Value field = desc.getIdxMemRef(dim);
+
+    // Insert a cast to bridge the actual type to the user-expected type. If
+    // the actual type and the user-expected type aren't compatible, the
+    // compiler or the runtime will issue an error.
+    Type resType = op.getResult().getType();
+    if (resType != field.getType())
+      field = rewriter.create<memref::CastOp>(op.getLoc(), resType, field);
+    rewriter.replaceOp(op, field);
+
+    return success();
   }
 };
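The same cast-bridging idiom appears again in the runtime-library path below. A condensed sketch of the pattern, with a hypothetical helper name that is not in the patch:

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"

// Hypothetical helper: hand the value back unchanged when its type already
// matches what the user asked for, otherwise bridge with a memref.cast.
// Whether the two memref types are actually compatible is left to the cast's
// verifier and, for dynamic shapes, to the runtime, as the comments in the
// patch describe.
static mlir::Value castToUserType(mlir::OpBuilder &builder, mlir::Location loc,
                                  mlir::Value v, mlir::Type expected) {
  if (v.getType() == expected)
    return v;
  return builder.create<mlir::memref::CastOp>(loc, expected, v);
}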
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -1122,10 +1122,24 @@
     Type resType = op.getType();
     Type indType = resType.cast<ShapedType>().getElementType();
     SmallString<15> name{"sparseIndices", overheadTypeFunctionSuffix(indType)};
-    Value dim =
-        constantIndex(rewriter, op->getLoc(), op.getDimension().getZExtValue());
-    replaceOpWithFuncCall(rewriter, op, name, resType,
-                          {adaptor.getTensor(), dim}, EmitCInterface::On);
+    Location loc = op->getLoc();
+    Value dim = constantIndex(rewriter, loc, op.getDimension().getZExtValue());
+
+    // The function returns a MemRef without a layout.
+    MemRefType callRetType = get1DMemRefType(indType, false);
+    SmallVector<Value> operands{adaptor.getTensor(), dim};
+    auto fn = getFunc(op->getParentOfType<ModuleOp>(), name, callRetType,
+                      operands, EmitCInterface::On);
+    Value callRet =
+        rewriter.create<func::CallOp>(loc, callRetType, fn, operands)
+            .getResult(0);
+
+    // Cast the MemRef type to the type expected by the user, though these
+    // two types should be compatible at runtime.
+    if (resType != callRetType)
+      callRet = rewriter.create<memref::CastOp>(loc, resType, callRet);
+    rewriter.replaceOp(op, callRet);
+
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -12,6 +12,7 @@
 
 #include "CodegenUtils.h"
 #include "LoopEmitter.h"
+#include "SparseTensorStorageLayout.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
@@ -809,14 +810,14 @@
   // Sort the COO tensor so that its elements are ordered via increasing
   // indices for the storage ordering of the dst tensor.
   SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp);
-  auto dynShape = {ShapedType::kDynamic};
-  auto indTp =
-      MemRefType::get(dynShape, getIndexOverheadType(rewriter, encSrc));
   uint64_t rank = dstTp.getRank();
+  uint64_t cooStart = getCooStart(encSrc);
   // Gather the indices-arrays in the dst tensor storage order.
   SmallVector<Value> xs(rank, Value());
   for (uint64_t i = 0; i < rank; i++) {
     uint64_t orgDim = toOrigDim(encSrc, i);
+    MemRefType indTp = get1DMemRefType(
+        getIndexOverheadType(rewriter, encSrc), i >= cooStart);
     xs[toStoredDim(encDst, orgDim)] = rewriter.create<ToIndicesOp>(
         loc, indTp, src, rewriter.getIndexAttr(i));
   }
@@ -827,7 +828,7 @@
            nnz);
 
   // Retrieve the values-array.
-  auto valTp = MemRefType::get(dynShape, srcTp.getElementType());
+  MemRefType valTp = get1DMemRefType(srcTp.getElementType(), false);
   Value y = rewriter.create<ToValuesOp>(loc, valTp, src);
 
   // Sort the COO tensor.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
@@ -64,6 +64,22 @@
 static_assert(static_cast<int>(SparseTensorFieldKind::ValMemRef) ==
               static_cast<int>(StorageSpecifierKind::ValMemSize));
 
+/// Returns the starting dimension for a trailing COO region that spans across
+/// at least two dimensions. If no such COO region is found, returns the rank
+/// of the tensor.
+inline unsigned getCooStart(SparseTensorEncodingAttr enc) {
+  unsigned rank = enc.getDimLevelType().size();
+  if (rank <= 1)
+    return rank;
+
+  for (unsigned r = 0; r < rank - 1; r++) {
+    if (isUniqueCOOType(enc, r))
+      return r;
+  }
+
+  return rank;
+}
+
 /// For each field that will be allocated for the given sparse tensor encoding,
 /// calls the callback with the corresponding field index, field kind, dimension
 /// (for sparse tensor level memrefs) and dimlevelType.
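A rough illustration of how getCooStart drives the new layouts, under the assumption that the encoding attribute is built elsewhere: for dimLevelType = ["compressed-nu", "singleton"] the trailing COO region starts at dimension 0, so every indices memref carries the strided layout, while for dimLevelType = ["dense", "compressed"] no trailing COO region exists, getCooStart returns the rank, and no layout is added. The helper below is hypothetical and simply mirrors the LoopEmitter and rewriting call sites above.

// Assumed local headers from this patch providing getCooStart and
// get1DMemRefType.
#include "CodegenUtils.h"
#include "SparseTensorStorageLayout.h"

// Hypothetical helper (not in the patch): only indices memrefs at or past the
// start of the trailing COO region are typed with the unknown-strides/offset
// layout.
static mlir::MemRefType
getIndicesMemRefType(mlir::sparse_tensor::SparseTensorEncodingAttr enc,
                     mlir::Type indexOverheadTp, unsigned d) {
  unsigned cooStart = mlir::sparse_tensor::getCooStart(enc);
  return mlir::sparse_tensor::get1DMemRefType(indexOverheadTp,
                                              /*withLayout=*/d >= cooStart);
}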
diff --git a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
--- a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
@@ -126,7 +126,7 @@
 // CHECK-RWT: %[[I1:.*]] = sparse_tensor.indices %[[COO]] {dimension = 1 : index}
 // CHECK-RWT: %[[NNZ:.*]] = sparse_tensor.number_of_entries %[[COO]]
 // CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[COO]]
-// CHECK-RWT: sparse_tensor.sort %[[NNZ]], %[[I0]], %[[I1]] jointly %[[V]] : memref<?xindex>, memref<?xindex> jointly memref<?xf64>
+// CHECK-RWT: sparse_tensor.sort %[[NNZ]], %[[I0]], %[[I1]] jointly %[[V]]
 // CHECK-RWT: %[[T3:.*]] = bufferization.alloc_tensor()
 // CHECK-RWT: %[[T4:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[T3]])
 // CHECK-RWT: ^bb0(%[[L1I0:.*]]: index, %[[L1I1:.*]]: index, %[[L1V:.*]]: f64, %[[L1T:.*]]: tensor
@@ -186,7 +186,7 @@
 // CHECK-RWT: %[[I1:.*]] = sparse_tensor.indices %[[COO]] {dimension = 1 : index}
 // CHECK-RWT: %[[NNZ:.*]] = sparse_tensor.number_of_entries %[[COO]]
 // CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[COO]]
-// CHECK-RWT: sparse_tensor.sort %[[NNZ]], %[[I0]], %[[I1]] jointly %[[V]] : memref<?xindex>, memref<?xindex> jointly memref<?xf32>
+// CHECK-RWT: sparse_tensor.sort %[[NNZ]], %[[I0]], %[[I1]] jointly %[[V]]
 // CHECK-RWT: %[[T3:.*]] = bufferization.alloc_tensor()
 // CHECK-RWT: %[[T4:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[T3]])
 // CHECK-RWT: ^bb0(%[[L1I0:.*]]: index, %[[L1I1:.*]]: index, %[[L1V:.*]]: f32, %[[L1T:.*]]: tensor
diff --git a/mlir/test/Dialect/SparseTensor/sorted_coo.mlir b/mlir/test/Dialect/SparseTensor/sorted_coo.mlir
--- a/mlir/test/Dialect/SparseTensor/sorted_coo.mlir
+++ b/mlir/test/Dialect/SparseTensor/sorted_coo.mlir
@@ -71,17 +71,17 @@
 // CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
 // CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 0 : index} : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xindex, strided<[?], offset: ?>>
+// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xindex, strided<[?], offset: ?>>
 // CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xf64>
 // CHECK-DAG: %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<64xf64>
 // CHECK-DAG: %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xf64>
 // CHECK-DAG: %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK-DAG: %[[VAL_12:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK: scf.for %[[VAL_13:.*]] = %[[VAL_11]] to %[[VAL_12]] step %[[VAL_4]] {
-// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref<?xindex>
+// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref<?xindex, strided<[?], offset: ?>>
-// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_13]]] : memref<?xindex>
+// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_13]]] : memref<?xindex, strided<[?], offset: ?>>
 // CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_16]]] : memref<64xf64>
 // CHECK: %[[VAL_19:.*]] = arith.mulf %[[VAL_17]], %[[VAL_18]] : f64
@@ -113,12 +113,12 @@
 // CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
 // CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 0 : index} : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xindex, strided<[?], offset: ?>>
+// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xindex, strided<[?], offset: ?>>
 // CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xf64>
 // CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.pointers %[[VAL_1]] {dimension = 0 : index} : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 0 : index} : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 1 : index} : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 0 : index} : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xindex, strided<[?], offset: ?>>
+// CHECK-DAG: %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 1 : index} : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xindex, strided<[?], offset: ?>>
 // CHECK-DAG: %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xf64>
 // CHECK-DAG: %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x64xf64>
 // CHECK-DAG: linalg.fill ins(%[[VAL_3]] : f64) outs(%[[VAL_14]] : memref<32x64xf64>)
@@ -133,8 +133,8 @@
 // CHECK: scf.condition(%[[VAL_24]]) %[[VAL_20]], %[[VAL_21]] : index, index
 // CHECK: } do {
 // CHECK: ^bb0(%[[VAL_25:.*]]: index, %[[VAL_26:.*]]: index):
-// CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_25]]] : memref<?xindex>
-// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_26]]] : memref<?xindex>
+// CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_25]]] : memref<?xindex, strided<[?], offset: ?>>
+// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_26]]] : memref<?xindex, strided<[?], offset: ?>>
-// CHECK: %[[VAL_45:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_43]]] : memref<?xindex>
+// CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_42]]] : memref<?xindex, strided<[?], offset: ?>>