diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt @@ -19,6 +19,7 @@ MLIRSCF MLIRStandard MLIRSparseTensor + MLIRTensor MLIRTransforms MLIRVector ) diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -19,6 +19,7 @@ #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Transforms/DialectConversion.h" using namespace mlir; @@ -103,15 +104,21 @@ return failure(); // User pointer. params.push_back(operands[0]); - // Sparsity annotations. + // Sparsity annotations in tensor constant form. Note that we cast + // the static shape into a dynamic shape to ensure that the method + // signature remains uniform across different tensor dimensions. SmallVector attrs; unsigned sz = enc.getDimLevelType().size(); for (unsigned i = 0; i < sz; i++) attrs.push_back(enc.getDimLevelType()[i] == SparseTensorEncodingAttr::DimLevelType::Compressed); - auto elts = DenseElementsAttr::get( - RankedTensorType::get({sz}, rewriter.getIntegerType(1)), attrs); - params.push_back(rewriter.create(loc, elts)); + Type etp = rewriter.getIntegerType(1); + RankedTensorType tt1 = RankedTensorType::get({sz}, etp); + RankedTensorType tt2 = + RankedTensorType::get({ShapedType::kDynamicSize}, etp); + auto elts = + rewriter.create(loc, DenseElementsAttr::get(tt1, attrs)); + params.push_back(rewriter.create(loc, tt2, elts)); // Seconary and primary types encoding. 
unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth()); unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth()); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -120,6 +120,7 @@ target.addDynamicallyLegalOp( [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); }); target.addLegalOp(); + target.addLegalOp(); populateFuncOpTypeConversionPattern(patterns, converter); populateCallOpTypeConversionPattern(patterns, converter); populateSparseTensorConversionPatterns(converter, patterns); diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir --- a/mlir/test/Dialect/SparseTensor/conversion.mlir +++ b/mlir/test/Dialect/SparseTensor/conversion.mlir @@ -16,6 +16,10 @@ indexBitWidth = 32 }> +#SparseMatrix = #sparse_tensor.encoding<{ + dimLevelType = ["dense", "compressed"] +}> + // CHECK-LABEL: func @sparse_dim( // CHECK-SAME: %[[A:.*]]: !llvm.ptr) // CHECK: %[[C:.*]] = constant 0 : index @@ -27,15 +31,28 @@ return %0 : index } -// CHECK-LABEL: func @sparse_new( +// CHECK-LABEL: func @sparse_new1d( // CHECK-SAME: %[[A:.*]]: !llvm.ptr) -> !llvm.ptr -// CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]] +// CHECK: %[[D:.*]] = constant dense : tensor<1xi1> +// CHECK: %[[C:.*]] = tensor.cast %[[D]] : tensor<1xi1> to tensor +// CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]], %[[C]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, tensor, i64, i64, i64) -> !llvm.ptr // CHECK: return %[[T]] : !llvm.ptr -func @sparse_new(%arg0: !llvm.ptr) -> tensor<128xf64, #SparseVector> { +func @sparse_new1d(%arg0: !llvm.ptr) -> tensor<128xf64, #SparseVector> { %0 = sparse_tensor.new %arg0 : !llvm.ptr to tensor<128xf64, #SparseVector> return %0 : tensor<128xf64, #SparseVector> } +// 
CHECK-LABEL: func @sparse_new2d( +// CHECK-SAME: %[[A:.*]]: !llvm.ptr) -> !llvm.ptr +// CHECK: %[[D:.*]] = constant dense<[false, true]> : tensor<2xi1> +// CHECK: %[[C:.*]] = tensor.cast %[[D]] : tensor<2xi1> to tensor +// CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]], %[[C]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, tensor, i64, i64, i64) -> !llvm.ptr +// CHECK: return %[[T]] : !llvm.ptr +func @sparse_new2d(%arg0: !llvm.ptr) -> tensor { + %0 = sparse_tensor.new %arg0 : !llvm.ptr to tensor + return %0 : tensor +} + // CHECK-LABEL: func @sparse_pointers( // CHECK-SAME: %[[A:.*]]: !llvm.ptr) // CHECK: %[[C:.*]] = constant 0 : index