diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -578,7 +578,8 @@
 
   Value targetLen = constantIndex(builder, loc, len);
   Value bufferLen = linalg::createOrFoldDimOp(builder, loc, buffer, 0);
-  Value reallocP = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
+  // Reallocates if target length is greater than the actual buffer len.
+  Value reallocP = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
                                                  targetLen, bufferLen);
   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, retTp, reallocP, true);
   // If targetLen > bufferLen, reallocate to get enough sparse to return.
diff --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
@@ -43,7 +43,7 @@
 // CHECK: %[[VAL_4:.*]] = arith.constant 6 : index
 // CHECK: %[[VAL_5:.*]] = arith.constant 0 : index
 // CHECK: %[[VAL_6:.*]] = memref.dim %[[VAL_2]], %[[VAL_5]] : memref<?xf64>
-// CHECK: %[[VAL_7:.*]] = arith.cmpi ult, %[[VAL_4]], %[[VAL_6]] : index
+// CHECK: %[[VAL_7:.*]] = arith.cmpi ugt, %[[VAL_4]], %[[VAL_6]] : index
 // CHECK: %[[VAL_8:.*]] = scf.if %[[VAL_7]] -> (memref<6xf64>) {
 // CHECK:   %[[VAL_9:.*]] = memref.realloc %[[VAL_2]] : memref<?xf64> to memref<6xf64>
 // CHECK:   scf.yield %[[VAL_9]] : memref<6xf64>
@@ -53,7 +53,7 @@
 // CHECK: }
 // CHECK: %[[VAL_11:.*]] = arith.constant 12 : index
 // CHECK: %[[VAL_12:.*]] = memref.dim %[[VAL_1]], %[[VAL_5]] : memref<?xi32>
-// CHECK: %[[VAL_13:.*]] = arith.cmpi ult, %[[VAL_11]], %[[VAL_12]] : index
+// CHECK: %[[VAL_13:.*]] = arith.cmpi ugt, %[[VAL_11]], %[[VAL_12]] : index
 // CHECK: %[[VAL_14:.*]] = scf.if %[[VAL_13]] -> (memref<12xi32>) {
 // CHECK:   %[[VAL_15:.*]] = memref.realloc %[[VAL_1]] : memref<?xi32> to memref<12xi32>
 // CHECK:   scf.yield %[[VAL_15]] : memref<12xi32>
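
Note for reviewers: the predicate flip above is the entire functional change. The old ult comparison reallocated only when the target length was smaller than the existing buffer, which is the opposite of the intended behavior; the sparse_pack.mlir CHECK updates simply mirror the new predicate in the generated IR. Below is a minimal standalone C++ sketch of the condition after the fix, assuming a made-up ensureCapacity helper over a plain double buffer and ordinary realloc; it is an illustration only, not the MLIR codegen API.

    // Illustrative only: a plain-C++ analogue of the condition emitted by
    // reallocOrSubView after this change. The helper name and the use of
    // std::realloc are assumptions for the sketch, not part of the patch.
    #include <cstddef>
    #include <cstdlib>

    double *ensureCapacity(double *buffer, size_t bufferLen, size_t targetLen) {
      // Grow only when the requested length exceeds the current length
      // (targetLen > bufferLen, i.e. the new "ugt" predicate). The old "ult"
      // predicate inverted this and skipped reallocation exactly when it was
      // needed. Error handling for a failed realloc is omitted in this sketch.
      if (targetLen > bufferLen)
        return static_cast<double *>(
            std::realloc(buffer, targetLen * sizeof(double)));
      return buffer; // already large enough; keep (a view of) the buffer
    }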