diff --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
@@ -182,6 +182,7 @@
         continue;
       // Conjunction already covered?
       for (unsigned p2 : latSets[s]) {
+        assert(!latGT(p1, p2)); // Lj => Li would be bad
         if (onlyDenseDiff(p2, p1)) {
           add = false;
           break;
@@ -752,6 +753,17 @@
   return val;
 }
 
+/// Generates an address computation "sz * p + i".
+static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
+                        Location loc, Value size, Value p, Value i) {
+  Value mul = rewriter.create<MulIOp>(loc, size, p);
+  if (auto vtp = i.getType().dyn_cast<VectorType>()) {
+    Value inv = rewriter.create<IndexCastOp>(loc, mul, vtp.getElementType());
+    mul = genVectorInvariantValue(codegen, rewriter, inv);
+  }
+  return rewriter.create<AddIOp>(loc, mul, i);
+}
+
 /// Recursively generates tensor expression.
 static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                     linalg::GenericOp op, unsigned exp) {
@@ -1073,9 +1085,8 @@
           break;
         Value p = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
                              : codegen.pidxs[tensor][topSort[pat - 1]];
-        Value m = rewriter.create<MulIOp>(loc, codegen.sizes[idx], p);
-        codegen.pidxs[tensor][idx] =
-            rewriter.create<AddIOp>(loc, m, codegen.loops[idx]);
+        codegen.pidxs[tensor][idx] = genAddress(
+            codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]);
       }
     }
   }
diff --git a/mlir/test/Dialect/Linalg/sparse_vector.mlir b/mlir/test/Dialect/Linalg/sparse_vector.mlir
--- a/mlir/test/Dialect/Linalg/sparse_vector.mlir
+++ b/mlir/test/Dialect/Linalg/sparse_vector.mlir
@@ -145,6 +145,40 @@
   return %0 : tensor<1024xf32>
 }
 
+//
+// CHECK-VEC2-LABEL: func @mul_s_alt
+// CHECK-VEC2-DAG:   %[[c0:.*]] = constant 0 : index
+// CHECK-VEC2-DAG:   %[[c1:.*]] = constant 1 : index
+// CHECK-VEC2-DAG:   %[[c16:.*]] = constant 16 : index
+// CHECK-VEC2:       %[[p:.*]] = load %{{.*}}[%[[c0]]] : memref<?xi32>
+// CHECK-VEC2:       %[[q:.*]] = index_cast %[[p]] : i32 to index
+// CHECK-VEC2:       %[[r:.*]] = load %{{.*}}[%[[c1]]] : memref<?xi32>
+// CHECK-VEC2:       %[[s:.*]] = index_cast %[[r]] : i32 to index
+// CHECK-VEC2:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] {
+// CHECK-VEC2:         %[[sub:.*]] = subi %[[s]], %[[i]] : index
+// CHECK-VEC2:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
+// CHECK-VEC2:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
+// CHECK-VEC2:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC2:         %[[lb:.*]] = vector.gather %{{.*}}[%[[li]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC2:         %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
+// CHECK-VEC2:         vector.scatter %{{.*}}[%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+// CHECK-VEC2:       }
+// CHECK-VEC2:       return
+//
+!SparseTensor = type !llvm.ptr<i8>
+func @mul_s_alt(%argA: !SparseTensor, %argB: !SparseTensor, %argx: tensor<1024xf32>) -> tensor<1024xf32> {
+  %arga = linalg.sparse_tensor %argA : !SparseTensor to tensor<1024xf32>
+  %argb = linalg.sparse_tensor %argB : !SparseTensor to tensor<1024xf32>
+  %0 = linalg.generic #trait_mul_s
+    ins(%arga, %argb: tensor<1024xf32>, tensor<1024xf32>)
+    outs(%argx: tensor<1024xf32>) {
+      ^bb(%a: f32, %b: f32, %x: f32):
+        %0 = mulf %a, %b : f32
+        linalg.yield %0 : f32
+  } -> tensor<1024xf32>
+  return %0 : tensor<1024xf32>
+}
+
 #trait_reduction_d = {
   indexing_maps = [
     affine_map<(i) -> (i)>,  // a
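
Note: the scalar path of genAddress emits the same "sz * p + i" arithmetic as the inline code it replaces in the hunk at line 1073, so that hunk is behavior-preserving; what is new is the vector branch, where the loop-invariant scalar part "size * p" is cast to the element type of the index vector and broadcast via genVectorInvariantValue before the vector add. Roughly, the two paths produce IR along these lines (a sketch only, with illustrative SSA names, assuming the vector<16xi32> index vectors used in the test above):

  // Scalar loop index: plain address arithmetic, as before.
  %m = muli %size, %p : index
  %a = addi %m, %i : index

  // Vector loop index: "size * p" is computed once as a scalar, cast to
  // the element type of the index vector, and broadcast before the add.
  %m = muli %size, %p : index
  %c = index_cast %m : index to i32
  %b = vector.broadcast %c : i32 to vector<16xi32>
  %a = addi %b, %vi : vector<16xi32>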