diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1098,13 +1098,7 @@ SparseVectorizationStrategy v, unsigned vl, SparseIntType pt, SparseIntType it, bool fo) : parallelizationStrategy(p), vectorizationStrategy(v), vectorLength(vl), - ptrType(pt), indType(it), fastOutput(fo) { - // TODO: remove restriction when vectors with index elements are supported - assert((v != SparseVectorizationStrategy::kAnyStorageInnerLoop || - (ptrType != SparseIntType::kNative && - indType != SparseIntType::kNative)) && - "This combination requires support for vectors with index elements"); - } + ptrType(pt), indType(it), fastOutput(fo) {} SparsificationOptions() : SparsificationOptions(SparseParallelizationStrategy::kNone, SparseVectorizationStrategy::kNone, 1u, diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -1684,7 +1684,7 @@ Vector_Op<"gather">, Arguments<(ins Arg:$base, Variadic:$indices, - VectorOfRankAndType<[1], [AnyInteger]>:$index_vec, + VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec, VectorOfRankAndType<[1], [I1]>:$mask, VectorOfRank<[1]>:$pass_thru)>, Results<(outs VectorOfRank<[1]>:$result)> { @@ -1749,7 +1749,7 @@ Vector_Op<"scatter">, Arguments<(ins Arg:$base, Variadic:$indices, - VectorOfRankAndType<[1], [AnyInteger]>:$index_vec, + VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec, VectorOfRankAndType<[1], [I1]>:$mask, VectorOfRank<[1]>:$valueToStore)> { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp @@ -771,12 +771,14 @@ // extremely large offsets. Type etp = ptr.getType().cast().getElementType(); Value vload = genVectorLoad(codegen, rewriter, ptr, {s}); - if (etp.getIntOrFloatBitWidth() < 32) - vload = rewriter.create( - loc, vload, vectorType(codegen, rewriter.getIntegerType(32))); - else if (etp.getIntOrFloatBitWidth() < 64) - vload = rewriter.create( - loc, vload, vectorType(codegen, rewriter.getIntegerType(64))); + if (!etp.isa()) { + if (etp.getIntOrFloatBitWidth() < 32) + vload = rewriter.create( + loc, vload, vectorType(codegen, rewriter.getIntegerType(32))); + else if (etp.getIntOrFloatBitWidth() < 64) + vload = rewriter.create( + loc, vload, vectorType(codegen, rewriter.getIntegerType(64))); + } return vload; } // For the scalar case, we simply zero extend narrower indices into 64-bit diff --git a/mlir/test/Dialect/Linalg/sparse_vector.mlir b/mlir/test/Dialect/Linalg/sparse_vector.mlir --- a/mlir/test/Dialect/Linalg/sparse_vector.mlir +++ b/mlir/test/Dialect/Linalg/sparse_vector.mlir @@ -4,6 +4,8 @@ // RUN: FileCheck %s --check-prefix=CHECK-VEC1 // RUN: mlir-opt %s -test-sparsification="vectorization-strategy=2 ptr-type=2 ind-type=2 vl=16" | \ // RUN: FileCheck %s --check-prefix=CHECK-VEC2 +// RUN: mlir-opt %s -test-sparsification="vectorization-strategy=2 ptr-type=0 ind-type=0 vl=16" | \ +// RUN: FileCheck %s --check-prefix=CHECK-VEC3 #trait_scale_d = { indexing_maps = [ @@ -54,6 +56,18 @@ // CHECK-VEC2: } // CHECK-VEC2: return // +// CHECK-VEC3-LABEL: func @scale_d +// CHECK-VEC3-DAG: %[[c0:.*]] = constant 0 : index +// CHECK-VEC3-DAG: %[[c16:.*]] = constant 16 : index +// CHECK-VEC3-DAG: %[[c1024:.*]] = constant 1024 : index +// CHECK-VEC3: scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] { +// CHECK-VEC3: %[[r:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32> +// CHECK-VEC3: %[[b:.*]] = vector.broadcast %{{.*}} : f32 to vector<16xf32> +// CHECK-VEC3: %[[m:.*]] = mulf %[[r]], %[[b]] : vector<16xf32> +// CHECK-VEC3: vector.store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32> +// CHECK-VEC3: } +// CHECK-VEC3: return +// func @scale_d(%arga: tensor<1024xf32>, %scale: f32, %argx: tensor<1024xf32>) -> tensor<1024xf32> { %0 = linalg.generic #trait_scale_d ins(%arga: tensor<1024xf32>) @@ -143,6 +157,23 @@ // CHECK-VEC2: } // CHECK-VEC2: return // +// CHECK-VEC3-LABEL: func @mul_s +// CHECK-VEC3-DAG: %[[c0:.*]] = constant 0 : index +// CHECK-VEC3-DAG: %[[c1:.*]] = constant 1 : index +// CHECK-VEC3-DAG: %[[c16:.*]] = constant 16 : index +// CHECK-VEC3: %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref +// CHECK-VEC3: %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref +// CHECK-VEC3: scf.for %[[i:.*]] = %[[p]] to %[[r]] step %[[c16]] { +// CHECK-VEC3: %[[sub:.*]] = subi %[[r]], %[[i]] : index +// CHECK-VEC3: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1> +// CHECK-VEC3: %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref, vector<16xi1>, vector<16xindex> into vector<16xindex> +// CHECK-VEC3: %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> +// CHECK-VEC3: %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32> into vector<16xf32> +// CHECK-VEC3: %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32> +// CHECK-VEC3: vector.scatter %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32> +// CHECK-VEC3: } +// CHECK-VEC3: return +// func @mul_s(%arga: tensor<1024xf32>, %argb: tensor<1024xf32>, %argx: tensor<1024xf32>) -> tensor<1024xf32> { %0 = linalg.generic #trait_mul_s ins(%arga, %argb: tensor<1024xf32>, tensor<1024xf32>) @@ -177,6 +208,24 @@ // CHECK-VEC2: } // CHECK-VEC2: return // +// CHECK-VEC3-LABEL: func @mul_s_alt +// CHECK-VEC3-DAG: %[[c0:.*]] = constant 0 : index +// CHECK-VEC3-DAG: %[[c1:.*]] = constant 1 : index +// CHECK-VEC3-DAG: %[[c16:.*]] = constant 16 : index +// CHECK-VEC3: %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref +// CHECK-VEC3: %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref +// CHECK-VEC3: scf.for %[[i:.*]] = %[[p]] to %[[r]] step %[[c16]] { +// CHECK-VEC3: %[[sub:.*]] = subi %[[r]], %[[i]] : index +// CHECK-VEC3: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1> +// CHECK-VEC3: %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref, vector<16xi1>, vector<16xindex> into vector<16xindex> +// CHECK-VEC3: %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> +// CHECK-VEC3: %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %{{.*}} : memref, vector<16xindex>, vector<16xi1>, vector<16xf32> into vector<16xf32> +// CHECK-VEC3: %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32> +// CHECK-VEC3: vector.scatter %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32> +// CHECK-VEC3: } +// CHECK-VEC3: return +// +// !SparseTensor = type !llvm.ptr func @mul_s_alt(%argA: !SparseTensor, %argB: !SparseTensor, %argx: tensor<1024xf32>) -> tensor<1024xf32> { %arga = linalg.sparse_tensor %argA : !SparseTensor to tensor<1024xf32> @@ -250,6 +299,21 @@ // CHECK-VEC2: %{{.*}} = vector.reduction "add", %[[red]], %{{.*}} : vector<16xf32> into f32 // CHECK-VEC2: return // +// CHECK-VEC3-LABEL: func @reduction_d +// CHECK-VEC3-DAG: %[[c0:.*]] = constant 0 : index +// CHECK-VEC3-DAG: %[[c16:.*]] = constant 16 : index +// CHECK-VEC3-DAG: %[[c1024:.*]] = constant 1024 : index +// CHECK-VEC3-DAG: %[[v0:.*]] = constant dense<0.000000e+00> : vector<16xf32> +// CHECK-VEC3: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[v0]]) -> (vector<16xf32>) { +// CHECK-VEC3: %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32> +// CHECK-VEC3: %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32> +// CHECK-VEC3: %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32> +// CHECK-VEC3: %[[a:.*]] = addf %[[red_in]], %[[m]] : vector<16xf32> +// CHECK-VEC3: scf.yield %[[a]] : vector<16xf32> +// CHECK-VEC3: } +// CHECK-VEC3: %{{.*}} = vector.reduction "add", %[[red]], %{{.*}} : vector<16xf32> into f32 +// CHECK-VEC3: return +// func @reduction_d(%arga: tensor<1024xf32>, %argb: tensor<1024xf32>, %argx: tensor) -> tensor { %0 = linalg.generic #trait_reduction_d ins(%arga, %argb: tensor<1024xf32>, tensor<1024xf32>) @@ -383,6 +447,27 @@ // CHECK-VEC2: } // CHECK-VEC2: return // +// CHECK-VEC3-LABEL: func @mul_ds +// CHECK-VEC3-DAG: %[[c0:.*]] = constant 0 : index +// CHECK-VEC3-DAG: %[[c1:.*]] = constant 1 : index +// CHECK-VEC3-DAG: %[[c16:.*]] = constant 16 : index +// CHECK-VEC3-DAG: %[[c512:.*]] = constant 512 : index +// CHECK-VEC3: scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] { +// CHECK-VEC3: %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref +// CHECK-VEC3: %[[a:.*]] = addi %[[i]], %[[c1]] : index +// CHECK-VEC3: %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref +// CHECK-VEC3: scf.for %[[j:.*]] = %[[p]] to %[[r]] step %[[c16]] { +// CHECK-VEC3: %[[sub:.*]] = subi %[[r]], %[[j]] : index +// CHECK-VEC3: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1> +// CHECK-VEC3: %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref, vector<16xi1>, vector<16xindex> into vector<16xindex> +// CHECK-VEC3: %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> +// CHECK-VEC3: %[[lb:.*]] = vector.gather %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %{{.*}} : memref<512x1024xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32> into vector<16xf32> +// CHECK-VEC3: %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32> +// CHECK-VEC3: vector.scatter %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32> +// CHECK-VEC3: } +// CHECK-VEC3: } +// CHECK-VEC3: return +// func @mul_ds(%arga: tensor<512x1024xf32>, %argb: tensor<512x1024xf32>, %argx: tensor<512x1024xf32>) -> tensor<512x1024xf32> { %0 = linalg.generic #trait_mul_ds ins(%arga, %argb: tensor<512x1024xf32>, tensor<512x1024xf32>)