diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -48,15 +48,16 @@
 /// Sparsification options.
 struct SparsificationOptions {
   SparsificationOptions(SparseParallelizationStrategy p,
-                        SparseVectorizationStrategy v, unsigned vl)
-      : parallelizationStrategy(p), vectorizationStrategy(v), vectorLength(vl) {
-  }
+                        SparseVectorizationStrategy v, unsigned vl, bool e)
+      : parallelizationStrategy(p), vectorizationStrategy(v), vectorLength(vl),
+        enableSIMDIndex32(e) {}
   SparsificationOptions()
       : SparsificationOptions(SparseParallelizationStrategy::kNone,
-                              SparseVectorizationStrategy::kNone, 1u) {}
+                              SparseVectorizationStrategy::kNone, 1u, false) {}
   SparseParallelizationStrategy parallelizationStrategy;
   SparseVectorizationStrategy vectorizationStrategy;
   unsigned vectorLength;
+  bool enableSIMDIndex32;
 };
 
 /// Sets up sparsification rewriting rules with the given options.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -21,6 +21,16 @@
     "sparse_tensor::SparseTensorDialect",
     "vector::VectorDialect",
   ];
+  let options = [
+    Option<"parallelization", "parallelization-strategy", "int32_t", "0",
+           "Set the parallelization strategy">,
+    Option<"vectorization", "vectorization-strategy", "int32_t", "0",
+           "Set the vectorization strategy">,
+    Option<"vectorLength", "vl", "int32_t", "1",
+           "Set the vector length">,
+    Option<"enableSIMDIndex32", "enable-simd-index32", "bool", "false",
+           "Enable i32 indexing into vectors (for efficiency)">
+  ];
 }
 
 def SparseTensorConversion : Pass<"sparse-tensor-conversion", "ModuleOp"> {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -35,17 +35,6 @@
   SparsificationPass(const SparsificationPass &pass)
       : SparsificationBase<SparsificationPass>() {}
 
-  Option<int32_t> parallelization{
-      *this, "parallelization-strategy",
-      llvm::cl::desc("Set the parallelization strategy"), llvm::cl::init(0)};
-
-  Option<int32_t> vectorization{
-      *this, "vectorization-strategy",
-      llvm::cl::desc("Set the vectorization strategy"), llvm::cl::init(0)};
-
-  Option<int32_t> vectorLength{
-      *this, "vl", llvm::cl::desc("Set the vector length"), llvm::cl::init(1)};
-
   /// Returns parallelization strategy given on command line.
   SparseParallelizationStrategy parallelOption() {
     switch (parallelization) {
@@ -79,7 +68,7 @@
     RewritePatternSet patterns(ctx);
     // Translate strategy flags to strategy options.
     SparsificationOptions options(parallelOption(), vectorOption(),
-                                  vectorLength);
+                                  vectorLength, enableSIMDIndex32);
     // Apply rewriting.
     populateSparsificationPatterns(patterns, options);
    vector::populateVectorToVectorCanonicalizationPatterns(patterns);
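For reference, the same plumbing can be driven without the pass. The snippet below is a minimal sketch, not part of the patch: only the four-argument SparsificationOptions constructor and the populateSparsificationPatterns(patterns, options) call are taken from the change above; the include paths, the mlir:: qualification, and the kNone/kAnyStorageInnerLoop enumerator names are assumptions.

// Sketch only: build sparsification rewrite patterns with 32-bit SIMD
// gather/scatter indices enabled, mirroring SparsificationPass::runOnOperation
// above. Namespace and enumerator names here are assumptions.
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"

static void addSparsificationPatterns(mlir::RewritePatternSet &patterns) {
  mlir::SparsificationOptions options(
      mlir::SparseParallelizationStrategy::kNone,
      mlir::SparseVectorizationStrategy::kAnyStorageInnerLoop,
      /*vl=*/16, /*enableSIMDIndex32=*/true);
  mlir::populateSparsificationPatterns(patterns, options);
}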
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -768,9 +768,9 @@
     // zero extend the vector to an index width. For 8-bit and 16-bit values,
     // a 32-bit index width suffices. For 32-bit values, zero extending the
     // elements into 64-bit loses some performance since the 32-bit indexed
-    // gather/scatter is more efficient than the 64-bit index variant (in
-    // the future, we could introduce a flag that states the negative space
-    // of 32-bit indices is unused). For 64-bit values, there is no good way
+    // gather/scatter is more efficient than the 64-bit index variant (if the
+    // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
+    // preserve this performance). For 64-bit values, there is no good way
     // to state that the indices are unsigned, which creates the potential of
     // incorrect address calculations in the unlikely case we need such
     // extremely large offsets.
@@ -780,7 +780,8 @@
     if (etp.getIntOrFloatBitWidth() < 32)
       vload = rewriter.create<ZeroExtendIOp>(
           loc, vload, vectorType(codegen, rewriter.getIntegerType(32)));
-    else if (etp.getIntOrFloatBitWidth() < 64)
+    else if (etp.getIntOrFloatBitWidth() < 64 &&
+             !codegen.options.enableSIMDIndex32)
       vload = rewriter.create<ZeroExtendIOp>(
           loc, vload, vectorType(codegen, rewriter.getIntegerType(64)));
   }
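To spell out the rule the hunk above implements, here is a hedged restatement as a standalone helper; the function name is invented for illustration and is not part of the patch.

// Illustrative helper (not in the patch): the bit width of the integer vector
// used for SIMD gather/scatter indices, given the coordinate bit width and
// the new flag. This mirrors the if/else chain in Sparsification.cpp above.
static unsigned simdIndexWidth(unsigned coordBitWidth, bool enableSIMDIndex32) {
  if (coordBitWidth < 32)
    return 32;            // i8/i16 coordinates: zero extend to i32.
  if (coordBitWidth < 64 && !enableSIMDIndex32)
    return 64;            // i32 coordinates, flag off: zero extend to i64.
  return coordBitWidth;   // i32 with the flag on, or i64: used as is.
}

Enabling the flag is therefore only valid when the coordinates never leave the non-negative i32 range, which is exactly the condition the updated comment states.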
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
@@ -4,6 +4,8 @@
 // RUN:   FileCheck %s --check-prefix=CHECK-VEC1
 // RUN: mlir-opt %s -sparsification="vectorization-strategy=2 vl=16" | \
 // RUN:   FileCheck %s --check-prefix=CHECK-VEC2
+// RUN: mlir-opt %s -sparsification="vectorization-strategy=2 vl=16 enable-simd-index32=true" | \
+// RUN:   FileCheck %s --check-prefix=CHECK-VEC3
 
 #DenseVector = #sparse_tensor.encoding<{ dimLevelType = [ "dense" ] }>
 
@@ -148,6 +150,27 @@
 // CHECK-VEC2:       }
 // CHECK-VEC2:       return
 //
+// CHECK-VEC3-LABEL: func @mul_s
+// CHECK-VEC3-DAG:   %[[c0:.*]] = constant 0 : index
+// CHECK-VEC3-DAG:   %[[c1:.*]] = constant 1 : index
+// CHECK-VEC3-DAG:   %[[c16:.*]] = constant 16 : index
+// CHECK-VEC3:       %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
+// CHECK-VEC3:       %[[a:.*]] = zexti %[[p]] : i32 to i64
+// CHECK-VEC3:       %[[q:.*]] = index_cast %[[a]] : i64 to index
+// CHECK-VEC3:       %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xi32>
+// CHECK-VEC3:       %[[b:.*]] = zexti %[[r]] : i32 to i64
+// CHECK-VEC3:       %[[s:.*]] = index_cast %[[b]] : i64 to index
+// CHECK-VEC3:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] {
+// CHECK-VEC3:         %[[sub:.*]] = subi %{{.*}}, %[[i]] : index
+// CHECK-VEC3:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
+// CHECK-VEC3:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
+// CHECK-VEC3:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC3:         %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC3:         %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
+// CHECK-VEC3:         vector.scatter %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+// CHECK-VEC3:       }
+// CHECK-VEC3:       return
+//
 func @mul_s(%arga: tensor<1024xf32, #SparseVector>, %argb: tensor<1024xf32>, %argx: tensor<1024xf32>) -> tensor<1024xf32> {
   %0 = linalg.generic #trait_mul_s
     ins(%arga, %argb: tensor<1024xf32, #SparseVector>, tensor<1024xf32>)
@@ -310,6 +333,31 @@
 // CHECK-VEC2:       }
 // CHECK-VEC2:       return
 //
+// CHECK-VEC3-LABEL: func @mul_ds
+// CHECK-VEC3-DAG:   %[[c0:.*]] = constant 0 : index
+// CHECK-VEC3-DAG:   %[[c1:.*]] = constant 1 : index
+// CHECK-VEC3-DAG:   %[[c16:.*]] = constant 16 : index
+// CHECK-VEC3-DAG:   %[[c512:.*]] = constant 512 : index
+// CHECK-VEC3:       scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
+// CHECK-VEC3:         %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
+// CHECK-VEC3:         %[[a:.*]] = zexti %[[p]] : i32 to i64
+// CHECK-VEC3:         %[[q:.*]] = index_cast %[[a]] : i64 to index
+// CHECK-VEC3:         %[[a:.*]] = addi %[[i]], %[[c1]] : index
+// CHECK-VEC3:         %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref<?xi32>
+// CHECK-VEC3:         %[[b:.*]] = zexti %[[r]] : i32 to i64
+// CHECK-VEC3:         %[[s:.*]] = index_cast %[[b]] : i64 to index
+// CHECK-VEC3:         scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c16]] {
+// CHECK-VEC3:           %[[sub:.*]] = subi %[[s]], %[[j]] : index
+// CHECK-VEC3:           %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
+// CHECK-VEC3:           %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
+// CHECK-VEC3:           %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC3:           %[[lb:.*]] = vector.gather %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %{{.*}} : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC3:           %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
+// CHECK-VEC3:           vector.scatter %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+// CHECK-VEC3:         }
+// CHECK-VEC3:       }
+// CHECK-VEC3:       return
+//
 func @mul_ds(%arga: tensor<512x1024xf32, #SparseMatrix>, %argb: tensor<512x1024xf32>, %argx: tensor<512x1024xf32>) -> tensor<512x1024xf32> {
   %0 = linalg.generic #trait_mul_ds
     ins(%arga, %argb: tensor<512x1024xf32, #SparseMatrix>, tensor<512x1024xf32>)