diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -46,6 +46,11 @@
 /// dimension level type being unique.
 bool isUniqueCOOType(RankedTensorType tp);
 
+/// Returns the starting dimension of a trailing COO region that spans at
+/// least two dimensions. If no such COO region is found, returns the rank
+/// of the tensor.
+unsigned getCOOStart(SparseTensorEncodingAttr enc);
+
 //
 // Dimension level types.
 //
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -264,22 +264,42 @@
   return nullptr;
 }
 
+/// Returns true iff the given sparse tensor encoding attribute has a trailing
+/// COO region starting at dimension `s`. When `isUnique` is set, the last
+/// dimension of the region must also be unique.
+static bool isCOOType(SparseTensorEncodingAttr enc, uint64_t s, bool isUnique) {
+  uint64_t rank = enc.getDimLevelType().size();
+  assert(s < rank && "Dimension out of bounds");
+  if (!isCompressedDim(enc, s))
+    return false;
+
+  for (uint64_t i = s + 1; i < rank; ++i)
+    if (!isSingletonDim(enc, i))
+      return false;
+
+  // This works for rank == 1 (unique the only compressed) and rank > 1 (unique
+  // on the last singleton).
+  return !isUnique || isUniqueDLT(getDimLevelType(enc, rank - 1));
+}
+
 bool mlir::sparse_tensor::isUniqueCOOType(RankedTensorType tp) {
   SparseTensorEncodingAttr enc = getSparseTensorEncoding(tp);
-
   if (!enc)
     return false;
 
-  if (!isCompressedDim(tp, 0))
-    return false;
+  return isCOOType(enc, 0, /*isUnique=*/true);
+}
 
-  for (uint64_t i = 1, e = tp.getRank(); i < e; ++i)
-    if (!isSingletonDim(tp, i))
-      return false;
+unsigned mlir::sparse_tensor::getCOOStart(SparseTensorEncodingAttr enc) {
+  unsigned rank = enc.getDimLevelType().size();
+  if (rank <= 1)
+    return rank;
 
-  // This works for rank == 1 (unique the only compressed) and rank > 1 (unique
-  // on the last singleton).
-  return isUniqueDim(tp, tp.getRank() - 1);
+  for (unsigned r = 0; r < rank - 1; r++) {
+    if (isCOOType(enc, r, /*isUnique=*/false))
+      return r;
+  }
+
+  return rank;
 }
 
 uint64_t mlir::sparse_tensor::toOrigDim(SparseTensorEncodingAttr enc,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -183,6 +183,16 @@
 void sizesFromSrc(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
                   Location loc, Value src);
 
+/// Generates a 1D MemRefType with a dynamic size. When withLayout is set, the
+/// type carries a strided layout with unknown strides and unknown offset.
+inline MemRefType get1DMemRefType(Type etp, bool withLayout) {
+  auto layout = withLayout ? StridedLayoutAttr::get(etp.getContext(),
+                                                    ShapedType::kDynamic,
+                                                    {ShapedType::kDynamic})
+                           : StridedLayoutAttr();
+  return MemRefType::get(ShapedType::kDynamic, etp, layout);
+}
+
 /// Scans to top of generated loop.
 Operation *getTop(Operation *op);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -124,25 +124,25 @@
     auto rank = rtp.getRank();
     auto shape = rtp.getShape();
     auto enc = getSparseTensorEncoding(rtp);
-    auto dynShape = {ShapedType::kDynamic};
+    uint64_t cooStart = enc ? getCOOStart(enc) : rank;
     // Scan all dimensions of current tensor.
    for (int64_t d = 0; d < rank; d++) {
       // This should be called only once at beginning.
       assert(!ptrBuffer[t][d] && !idxBuffer[t][d] && !highs[t][d]);
       // Handle sparse storage schemes.
       if (isCompressedDLT(dimTypes[t][d])) {
-        auto ptrTp =
-            MemRefType::get(dynShape, getPointerOverheadType(builder, enc));
-        auto indTp =
-            MemRefType::get(dynShape, getIndexOverheadType(builder, enc));
+        auto ptrTp = get1DMemRefType(getPointerOverheadType(builder, enc),
+                                     /*withLayout=*/false);
+        auto indTp = get1DMemRefType(getIndexOverheadType(builder, enc),
+                                     /*withLayout=*/d >= cooStart);
         auto dim = builder.getIndexAttr(d);
         // Generate sparse primitives to obtains pointer and indices.
         ptrBuffer[t][d] = builder.create<ToPointersOp>(loc, ptrTp, tensor, dim);
         idxBuffer[t][d] = builder.create<ToIndicesOp>(loc, indTp, tensor, dim);
       } else if (isSingletonDLT(dimTypes[t][d])) {
         // Singleton dimension, fetch indices.
-        auto indTp =
-            MemRefType::get(dynShape, getIndexOverheadType(builder, enc));
+        auto indTp = get1DMemRefType(getIndexOverheadType(builder, enc),
+                                     /*withLayout=*/d >= cooStart);
         auto dim = builder.getIndexAttr(d);
         idxBuffer[t][d] = builder.create<ToIndicesOp>(loc, indTp, tensor, dim);
       } else {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -878,60 +878,65 @@
   }
 };
 
-/// Base class for getter-like operations, e.g., to_indices, to_pointers.
-template <typename SourceOp, typename Base>
-class SparseGetterOpConverter : public OpConversionPattern<SourceOp> {
+/// Sparse codegen rule for pointer accesses.
+class SparseToPointersConverter : public OpConversionPattern<ToPointersOp> {
 public:
-  using OpAdaptor = typename SourceOp::Adaptor;
-  using OpConversionPattern<SourceOp>::OpConversionPattern;
+  using OpAdaptor = typename ToPointersOp::Adaptor;
+  using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
+  matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Replace the requested pointer access with corresponding field.
     // The cast_op is inserted by type converter to intermix 1:N type
     // conversion.
     auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
-    Value field = Base::getFieldForOp(desc, op);
-    rewriter.replaceOp(op, field);
-    return success();
-  }
-};
-
-/// Sparse codegen rule for pointer accesses.
-class SparseToPointersConverter
-    : public SparseGetterOpConverter<ToPointersOp, SparseToPointersConverter> {
-public:
-  using SparseGetterOpConverter::SparseGetterOpConverter;
-  // Callback for SparseGetterOpConverter.
-  static Value getFieldForOp(const SparseTensorDescriptor &desc,
-                             ToPointersOp op) {
     uint64_t dim = op.getDimension().getZExtValue();
-    return desc.getPtrMemRef(dim);
+    rewriter.replaceOp(op, desc.getPtrMemRef(dim));
+    return success();
   }
 };
 
 /// Sparse codegen rule for index accesses.
-class SparseToIndicesConverter
-    : public SparseGetterOpConverter<ToIndicesOp, SparseToIndicesConverter> {
+class SparseToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
 public:
-  using SparseGetterOpConverter::SparseGetterOpConverter;
-  // Callback for SparseGetterOpConverter.
-  static Value getFieldForOp(const SparseTensorDescriptor &desc,
-                             ToIndicesOp op) {
+  using OpAdaptor = typename ToIndicesOp::Adaptor;
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Replace the requested indices access with the corresponding field.
+    // The cast_op is inserted by type converter to intermix 1:N type
+    // conversion.
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
     uint64_t dim = op.getDimension().getZExtValue();
-    return desc.getIdxMemRef(dim);
+    Value field = desc.getIdxMemRef(dim);
+
+    // Insert a cast to bridge the actual type to the type expected by the
+    // user. If the two types are incompatible, the compiler or the runtime
+    // will issue an error.
+    Type resType = op.getResult().getType();
+    if (resType != field.getType())
+      field = rewriter.create<memref::CastOp>(op.getLoc(), resType, field);
+    rewriter.replaceOp(op, field);
+
+    return success();
   }
 };
 
 /// Sparse codegen rule for value accesses.
-class SparseToValuesConverter
-    : public SparseGetterOpConverter<ToValuesOp, SparseToValuesConverter> {
+class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
 public:
-  using SparseGetterOpConverter::SparseGetterOpConverter;
-  // Callback for SparseGetterOpConverter.
-  static Value getFieldForOp(const SparseTensorDescriptor &desc,
-                             ToValuesOp /*op*/) {
-    return desc.getValMemRef();
+  using OpAdaptor = typename ToValuesOp::Adaptor;
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Replace the requested values access with the corresponding field.
+    // The cast_op is inserted by type converter to intermix 1:N type
+    // conversion.
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    rewriter.replaceOp(op, desc.getValMemRef());
+    return success();
   }
 };
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -1122,10 +1122,24 @@
     Type resType = op.getType();
     Type indType = resType.cast<ShapedType>().getElementType();
     SmallString<15> name{"sparseIndices", overheadTypeFunctionSuffix(indType)};
-    Value dim =
-        constantIndex(rewriter, op->getLoc(), op.getDimension().getZExtValue());
-    replaceOpWithFuncCall(rewriter, op, name, resType,
-                          {adaptor.getTensor(), dim}, EmitCInterface::On);
+    Location loc = op->getLoc();
+    Value dim = constantIndex(rewriter, loc, op.getDimension().getZExtValue());
+
+    // The function returns a MemRef without a layout.
+    MemRefType callRetType = get1DMemRefType(indType, /*withLayout=*/false);
+    SmallVector<Value> operands{adaptor.getTensor(), dim};
+    auto fn = getFunc(op->getParentOfType<ModuleOp>(), name, callRetType,
+                      operands, EmitCInterface::On);
+    Value callRet =
+        rewriter.create<func::CallOp>(loc, callRetType, fn, operands)
+            .getResult(0);
+
+    // Cast the MemRef type to the type expected by the users, though these
+    // two types should be compatible at runtime.
+    if (resType != callRetType)
+      callRet = rewriter.create<memref::CastOp>(loc, resType, callRet);
+    rewriter.replaceOp(op, callRet);
+
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -809,14 +809,14 @@
     // Sort the COO tensor so that its elements are ordered via increasing
     // indices for the storage ordering of the dst tensor.
     SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp);
-    auto dynShape = {ShapedType::kDynamic};
-    auto indTp =
-        MemRefType::get(dynShape, getIndexOverheadType(rewriter, encSrc));
     uint64_t rank = dstTp.getRank();
+    uint64_t cooStart = getCOOStart(encSrc);
     // Gather the indices-arrays in the dst tensor storage order.
     SmallVector<Value> xs(rank, Value());
     for (uint64_t i = 0; i < rank; i++) {
       uint64_t orgDim = toOrigDim(encSrc, i);
+      MemRefType indTp = get1DMemRefType(
+          getIndexOverheadType(rewriter, encSrc), i >= cooStart);
       xs[toStoredDim(encDst, orgDim)] = rewriter.create<ToIndicesOp>(
           loc, indTp, src, rewriter.getIndexAttr(i));
     }
@@ -827,7 +827,7 @@
                                              nnz);
 
     // Retrieve the values-array.
-    auto valTp = MemRefType::get(dynShape, srcTp.getElementType());
+    MemRefType valTp = get1DMemRefType(srcTp.getElementType(), false);
     Value y = rewriter.create<ToValuesOp>(loc, valTp, src);
 
     // Sort the COO tensor.
diff --git a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
--- a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
@@ -126,7 +126,7 @@
 // CHECK-RWT: %[[I1:.*]] = sparse_tensor.indices %[[COO]] {dimension = 1 : index}
 // CHECK-RWT: %[[NNZ:.*]] = sparse_tensor.number_of_entries %[[COO]]
 // CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[COO]]
-// CHECK-RWT: sparse_tensor.sort %[[NNZ]], %[[I0]], %[[I1]] jointly %[[V]] : memref<?xindex>, memref<?xindex> jointly memref<?xf64>
+// CHECK-RWT: sparse_tensor.sort %[[NNZ]], %[[I0]], %[[I1]] jointly %[[V]]
 // CHECK-RWT: %[[T3:.*]] = bufferization.alloc_tensor()
 // CHECK-RWT: %[[T4:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[T3]])
 // CHECK-RWT: ^bb0(%[[L1I0:.*]]: index, %[[L1I1:.*]]: index, %[[L1V:.*]]: f64, %[[L1T:.*]]: tensor
@@ -186,7 +186,7 @@
 // CHECK-RWT: %[[I1:.*]] = sparse_tensor.indices %[[COO]] {dimension = 1 : index}
 // CHECK-RWT: %[[NNZ:.*]] = sparse_tensor.number_of_entries %[[COO]]
 // CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[COO]]
-// CHECK-RWT: sparse_tensor.sort %[[NNZ]], %[[I0]], %[[I1]] jointly %[[V]] : memref<?xindex>, memref<?xindex> jointly memref<?xf32>
+// CHECK-RWT: sparse_tensor.sort %[[NNZ]], %[[I0]], %[[I1]] jointly %[[V]]
 // CHECK-RWT: %[[T3:.*]] = bufferization.alloc_tensor()
 // CHECK-RWT: %[[T4:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[T3]])
 // CHECK-RWT: ^bb0(%[[L1I0:.*]]: index, %[[L1I1:.*]]: index, %[[L1V:.*]]: f32, %[[L1T:.*]]: tensor
diff --git a/mlir/test/Dialect/SparseTensor/sorted_coo.mlir b/mlir/test/Dialect/SparseTensor/sorted_coo.mlir
--- a/mlir/test/Dialect/SparseTensor/sorted_coo.mlir
+++ b/mlir/test/Dialect/SparseTensor/sorted_coo.mlir
@@ -71,17 +71,17 @@
 // CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
 // CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 0 : index} : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xindex, strided<[?], offset: ?>>
+// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xindex, strided<[?], offset: ?>>
 // CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xf64>
 // CHECK-DAG: %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<64xf64>
 // CHECK-DAG: %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xf64>
 // CHECK-DAG: %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK-DAG: %[[VAL_12:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK: scf.for %[[VAL_13:.*]] = %[[VAL_11]] to %[[VAL_12]] step %[[VAL_4]] {
-// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref<?xindex>
+// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref<?xindex, strided<[?], offset: ?>>
-// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_13]]] : memref<?xindex>
+// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_13]]] : memref<?xindex, strided<[?], offset: ?>>
 // CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_16]]] : memref<64xf64>
 // CHECK: %[[VAL_19:.*]] = arith.mulf %[[VAL_17]], %[[VAL_18]] : f64
@@ -113,12 +113,12 @@
 // CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
 // CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 0 : index} : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xindex, strided<[?], offset: ?>>
+// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xindex, strided<[?], offset: ?>>
 // CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xf64>
 // CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.pointers %[[VAL_1]] {dimension = 0 : index} : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 0 : index} : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 1 : index} : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 0 : index} : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xindex, strided<[?], offset: ?>>
+// CHECK-DAG: %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 1 : index} : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xindex, strided<[?], offset: ?>>
 // CHECK-DAG: %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xf64>
 // CHECK-DAG: %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x64xf64>
 // CHECK-DAG: linalg.fill ins(%[[VAL_3]] : f64) outs(%[[VAL_14]] : memref<32x64xf64>)
@@ -133,8 +133,8 @@
 // CHECK: scf.condition(%[[VAL_24]]) %[[VAL_20]], %[[VAL_21]] : index, index
 // CHECK: } do {
 // CHECK: ^bb0(%[[VAL_25:.*]]: index, %[[VAL_26:.*]]: index):
-// CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_25]]] : memref<?xindex>
-// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_26]]] : memref<?xindex>
+// CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_25]]] : memref<?xindex, strided<[?], offset: ?>>
+// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_26]]] : memref<?xindex, strided<[?], offset: ?>>
-// CHECK: %[[VAL_45:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_43]]] : memref<?xindex>
+// CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_42]]] : memref<?xindex, strided<[?], offset: ?>>
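
Note (illustration only, not part of the patch): for the two-dimensional sorted-COO encoding used in the tests above (dimLevelType = [ "compressed-nu", "singleton" ]), getCOOStart returns 0, so the indices buffers of every COO dimension are now typed with an unknown-strided layout, while the pointers and values buffers keep the default layout. A minimal sketch of the resulting types, assuming a hypothetical #SortedCOO alias for that encoding and the placeholder 32x64xf64 shape from the tests:

  %ptr = sparse_tensor.pointers %coo {dimension = 0 : index}
           : tensor<32x64xf64, #SortedCOO> to memref<?xindex>
  %i0  = sparse_tensor.indices %coo {dimension = 0 : index}
           : tensor<32x64xf64, #SortedCOO> to memref<?xindex, strided<[?], offset: ?>>
  %i1  = sparse_tensor.indices %coo {dimension = 1 : index}
           : tensor<32x64xf64, #SortedCOO> to memref<?xindex, strided<[?], offset: ?>>
  %val = sparse_tensor.values %coo
           : tensor<32x64xf64, #SortedCOO> to memref<?xf64>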