diff --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp @@ -538,15 +538,8 @@ // Find lower and upper bound in current dimension. Value up; if (shape[d] == TensorType::kDynamicSize) { - // For the output tensor, we may need to infer the upper bound. - // For all others, we look at the incoming argument. - if (t == numInputs && !op.getNumInitTensors()) { - up = codegen.sizes[i]; - assert(up); // TODO: what else? - } else { - Value arg = t < numInputs ? op.getInput(t) : op.getInitTensors()[0]; - up = rewriter.create(loc, arg, d); - } + Value arg = t < numInputs ? op.getInput(t) : op.getOutput(0); + up = rewriter.create(loc, arg, d); args.push_back(up); } else { up = rewriter.create(loc, shape[d]); diff --git a/mlir/test/Dialect/Linalg/sparse_2d.mlir b/mlir/test/Dialect/Linalg/sparse_2d.mlir --- a/mlir/test/Dialect/Linalg/sparse_2d.mlir +++ b/mlir/test/Dialect/Linalg/sparse_2d.mlir @@ -1139,19 +1139,19 @@ // CHECK: %[[VAL_2:.*]] = constant 999 : index // CHECK: %[[VAL_3:.*]] = constant 0 : index // CHECK: %[[VAL_4:.*]] = constant 1 : index -// CHECK: %[[VAL_5:.*]] = dim %[[VAL_0]], %[[VAL_3]] : tensor +// CHECK: %[[VAL_5:.*]] = alloca(%[[VAL_2]]) : memref // CHECK: %[[VAL_6:.*]] = alloca(%[[VAL_2]]) : memref -// CHECK: %[[VAL_7:.*]] = alloca(%[[VAL_2]]) : memref -// CHECK: %[[VAL_8:.*]] = dim %[[VAL_0]], %[[VAL_4]] : tensor -// CHECK: %[[VAL_9:.*]] = alloca(%[[VAL_2]]) : memref -// CHECK: %[[VAL_10:.*]] = alloca(%[[VAL_5]], %[[VAL_8]]) : memref -// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_4]] { -// CHECK: %[[VAL_12:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref +// CHECK: %[[VAL_7:.*]] = alloca(%[[VAL_2]]) : memref +// CHECK: %[[VAL_8:.*]] = dim %[[VAL_0]], %[[VAL_3]] : tensor +// CHECK: %[[VAL_9:.*]] = dim %[[VAL_0]], %[[VAL_4]] : tensor +// CHECK: %[[VAL_10:.*]] = alloca(%[[VAL_8]], %[[VAL_9]]) : memref +// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_3]] to %[[VAL_8]] step %[[VAL_4]] { +// CHECK: %[[VAL_12:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_11]]] : memref // CHECK: %[[VAL_13:.*]] = addi %[[VAL_11]], %[[VAL_4]] : index -// CHECK: %[[VAL_14:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref +// CHECK: %[[VAL_14:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_13]]] : memref // CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_12]] to %[[VAL_14]] step %[[VAL_4]] { -// CHECK: %[[VAL_16:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref -// CHECK: %[[VAL_17:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_15]]] : memref +// CHECK: %[[VAL_16:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_15]]] : memref +// CHECK: %[[VAL_17:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref // CHECK: %[[VAL_18:.*]] = mulf %[[VAL_17]], %[[VAL_1]] : f64 // CHECK: store %[[VAL_18]], %[[VAL_10]]{{\[}}%[[VAL_11]], %[[VAL_16]]] : memref // CHECK: }