diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -53,6 +53,7 @@
   }];
   let constructor = "mlir::createSparsificationPass()";
   let dependentDialects = [
+    "AffineDialect",
     "LLVM::LLVMDialect",
     "memref::MemRefDialect",
     "scf::SCFDialect",
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -10,10 +10,12 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Transforms.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
@@ -354,7 +356,13 @@
   // during vector execution. Here we rely on subsequent loop optimizations to
   // avoid executing the mask in all iterations, for example, by splitting the
   // loop into an unconditional vector loop and a scalar cleanup loop.
-  Value end = rewriter.create<SubIOp>(loc, hi, iv);
+  auto minMap = AffineMap::get(
+      /*dimCount=*/2, /*symbolCount=*/1,
+      {rewriter.getAffineSymbolExpr(0),
+       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
+      rewriter.getContext());
+  Value end =
+      rewriter.createOrFold<AffineMinOp>(loc, minMap, ValueRange{hi, iv, step});
   return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
 }
 
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
@@ -1,26 +1,14 @@
-// RUN: mlir-opt %s -sparsification="vectorization-strategy=0 vl=16" | \
+// RUN: mlir-opt %s -sparsification="vectorization-strategy=0 vl=16" -split-input-file | \
 // RUN: FileCheck %s --check-prefix=CHECK-VEC0
-// RUN: mlir-opt %s -sparsification="vectorization-strategy=1 vl=16" | \
+// RUN: mlir-opt %s -sparsification="vectorization-strategy=1 vl=16" -split-input-file | \
 // RUN: FileCheck %s --check-prefix=CHECK-VEC1
-// RUN: mlir-opt %s -sparsification="vectorization-strategy=2 vl=16" | \
+// RUN: mlir-opt %s -sparsification="vectorization-strategy=2 vl=16" -split-input-file | \
 // RUN: FileCheck %s --check-prefix=CHECK-VEC2
-// RUN: mlir-opt %s -sparsification="vectorization-strategy=2 vl=16 enable-simd-index32=true" | \
+// RUN: mlir-opt %s -sparsification="vectorization-strategy=2 vl=16 enable-simd-index32=true" -split-input-file | \
 // RUN: FileCheck %s --check-prefix=CHECK-VEC3
 
 #DenseVector = #sparse_tensor.encoding<{ dimLevelType = [ "dense" ] }>
 
-#SparseVector = #sparse_tensor.encoding<{
-  dimLevelType = [ "compressed" ],
-  pointerBitWidth = 32,
-  indexBitWidth = 32
-}>
-
-#SparseMatrix = #sparse_tensor.encoding<{
-  dimLevelType = [ "dense", "compressed" ],
-  pointerBitWidth = 32,
-  indexBitWidth = 32
-}>
-
 #trait_scale_d = {
   indexing_maps = [
     affine_map<(i) -> (i)>,  // a
@@ -31,7 +19,7 @@
 }
 
 //
-// CHECK-VEC0-LABEL: func @scale_d
+// CHECK-VEC0: func @scale_d
 // CHECK-VEC0-DAG: %[[c0:.*]] = constant 0 : index
 // CHECK-VEC0-DAG: %[[c1:.*]] = constant 1 : index
 // CHECK-VEC0-DAG: %[[c1024:.*]] = constant 1024 : index
@@ -42,7 +30,7 @@
 // CHECK-VEC0: }
 // CHECK-VEC0: return
 //
-// CHECK-VEC1-LABEL: func @scale_d
+// CHECK-VEC1: func @scale_d
 // CHECK-VEC1-DAG: %[[c0:.*]] = constant 0 : index
 // CHECK-VEC1-DAG: %[[c16:.*]] = constant 16 : index
 // CHECK-VEC1-DAG: %[[c1024:.*]] = constant 1024 : index
@@ -54,7 +42,7 @@
 // CHECK-VEC1: }
 // CHECK-VEC1: return
 //
-// CHECK-VEC2-LABEL: func @scale_d
+// CHECK-VEC2: func @scale_d
 // CHECK-VEC2-DAG: %[[c0:.*]] = constant 0 : index
 // CHECK-VEC2-DAG: %[[c16:.*]] = constant 16 : index
 // CHECK-VEC2-DAG: %[[c1024:.*]] = constant 1024 : index
@@ -77,6 +65,14 @@
   return %0 : tensor<1024xf32>
 }
 
+// -----
+
+#SparseVector = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed" ],
+  pointerBitWidth = 32,
+  indexBitWidth = 32
+}>
+
 #trait_mul_s = {
   indexing_maps = [
     affine_map<(i) -> (i)>,  // a
@@ -88,7 +84,7 @@
 }
 
 //
-// CHECK-VEC0-LABEL: func @mul_s
+// CHECK-VEC0: func @mul_s
 // CHECK-VEC0-DAG: %[[c0:.*]] = constant 0 : index
 // CHECK-VEC0-DAG: %[[c1:.*]] = constant 1 : index
 // CHECK-VEC0: %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
@@ -108,7 +104,7 @@
 // CHECK-VEC0: }
 // CHECK-VEC0: return
 //
-// CHECK-VEC1-LABEL: func @mul_s
+// CHECK-VEC1: func @mul_s
 // CHECK-VEC1-DAG: %[[c0:.*]] = constant 0 : index
 // CHECK-VEC1-DAG: %[[c1:.*]] = constant 1 : index
 // CHECK-VEC1: %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
@@ -128,7 +124,8 @@
 // CHECK-VEC1: }
 // CHECK-VEC1: return
 //
-// CHECK-VEC2-LABEL: func @mul_s
+// CHECK-VEC2: #[[map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1)
+// CHECK-VEC2: func @mul_s
 // CHECK-VEC2-DAG: %[[c0:.*]] = constant 0 : index
 // CHECK-VEC2-DAG: %[[c1:.*]] = constant 1 : index
 // CHECK-VEC2-DAG: %[[c16:.*]] = constant 16 : index
@@ -139,7 +136,7 @@
 // CHECK-VEC2: %[[b:.*]] = zexti %[[r]] : i32 to i64
 // CHECK-VEC2: %[[s:.*]] = index_cast %[[b]] : i64 to index
 // CHECK-VEC2: scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] {
-// CHECK-VEC2: %[[sub:.*]] = subi %[[s]], %[[i]] : index
+// CHECK-VEC2: %[[sub:.*]] = affine.min #[[map]](%[[s]], %[[i]])[%[[c16]]]
 // CHECK-VEC2: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
 // CHECK-VEC2: %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
 // CHECK-VEC2: %[[zi:.*]] = zexti %[[li]] : vector<16xi32> to vector<16xi64>
@@ -150,7 +147,8 @@
 // CHECK-VEC2: }
 // CHECK-VEC2: return
 //
-// CHECK-VEC3-LABEL: func @mul_s
+// CHECK-VEC3: #[[map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1)
+// CHECK-VEC3: func @mul_s
 // CHECK-VEC3-DAG: %[[c0:.*]] = constant 0 : index
 // CHECK-VEC3-DAG: %[[c1:.*]] = constant 1 : index
 // CHECK-VEC3-DAG: %[[c16:.*]] = constant 16 : index
@@ -161,7 +159,7 @@
 // CHECK-VEC3: %[[b:.*]] = zexti %[[r]] : i32 to i64
 // CHECK-VEC3: %[[s:.*]] = index_cast %[[b]] : i64 to index
 // CHECK-VEC3: scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] {
-// CHECK-VEC3: %[[sub:.*]] = subi %{{.*}}, %[[i]] : index
+// CHECK-VEC3: %[[sub:.*]] = affine.min #[[map]](%[[s]], %[[i]])[%[[c16]]]
 // CHECK-VEC3: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
 // CHECK-VEC3: %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
 // CHECK-VEC3: %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
@@ -182,6 +180,10 @@
   return %0 : tensor<1024xf32>
 }
 
+// -----
+
+#DenseVector = #sparse_tensor.encoding<{ dimLevelType = [ "dense" ] }>
+
 #trait_reduction_d = {
   indexing_maps = [
     affine_map<(i) -> (i)>,  // a
@@ -193,7 +195,7 @@
 }
 
 //
-// CHECK-VEC0-LABEL: func @reduction_d
+// CHECK-VEC0: func @reduction_d
 // CHECK-VEC0-DAG: %[[c0:.*]] = constant 0 : index
 // CHECK-VEC0-DAG: %[[c1:.*]] = constant 1 : index
 // CHECK-VEC0-DAG: %[[c1024:.*]] = constant 1024 : index
@@ -206,7 +208,7 @@
 // CHECK-VEC0: }
 // CHECK-VEC0: return
 //
-// CHECK-VEC1-LABEL: func @reduction_d
+// CHECK-VEC1: func @reduction_d
 // CHECK-VEC1-DAG: %[[c0:.*]] = constant 0 : index
 // CHECK-VEC1-DAG: %[[c16:.*]] = constant 16 : index
 // CHECK-VEC1-DAG: %[[c1024:.*]] = constant 1024 : index
@@ -221,7 +223,7 @@
 // CHECK-VEC1: %{{.*}} = vector.reduction "add", %[[red]], %{{.*}} : vector<16xf32> into f32
 // CHECK-VEC1: return
 //
-// CHECK-VEC2-LABEL: func @reduction_d
+// CHECK-VEC2: func @reduction_d
 // CHECK-VEC2-DAG: %[[c0:.*]] = constant 0 : index
 // CHECK-VEC2-DAG: %[[c16:.*]] = constant 16 : index
 // CHECK-VEC2-DAG: %[[c1024:.*]] = constant 1024 : index
@@ -248,6 +250,14 @@
   return %0 : tensor<f32>
 }
 
+// -----
+
+#SparseMatrix = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  pointerBitWidth = 32,
+  indexBitWidth = 32
+}>
+
 #trait_mul_ds = {
   indexing_maps = [
     affine_map<(i,j) -> (i,j)>,  // A
@@ -259,7 +269,7 @@
 }
 
 //
-// CHECK-VEC0-LABEL: func @mul_ds
+// CHECK-VEC0: func @mul_ds
 // CHECK-VEC0-DAG: %[[c0:.*]] = constant 0 : index
 // CHECK-VEC0-DAG: %[[c1:.*]] = constant 1 : index
 // CHECK-VEC0-DAG: %[[c512:.*]] = constant 512 : index
@@ -283,7 +293,7 @@
 // CHECK-VEC0: }
 // CHECK-VEC0: return
 //
-// CHECK-VEC1-LABEL: func @mul_ds
+// CHECK-VEC1: func @mul_ds
 // CHECK-VEC1-DAG: %[[c0:.*]] = constant 0 : index
 // CHECK-VEC1-DAG: %[[c1:.*]] = constant 1 : index
 // CHECK-VEC1-DAG: %[[c512:.*]] = constant 512 : index
@@ -307,7 +317,8 @@
 // CHECK-VEC1: }
 // CHECK-VEC1: return
 //
-// CHECK-VEC2-LABEL: func @mul_ds
+// CHECK-VEC2: #[[map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1)
+// CHECK-VEC2: func @mul_ds
 // CHECK-VEC2-DAG: %[[c0:.*]] = constant 0 : index
 // CHECK-VEC2-DAG: %[[c1:.*]] = constant 1 : index
 // CHECK-VEC2-DAG: %[[c16:.*]] = constant 16 : index
@@ -321,7 +332,7 @@
 // CHECK-VEC2: %[[b:.*]] = zexti %[[r]] : i32 to i64
 // CHECK-VEC2: %[[s:.*]] = index_cast %[[b]] : i64 to index
 // CHECK-VEC2: scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c16]] {
-// CHECK-VEC2: %[[sub:.*]] = subi %[[s]], %[[j]] : index
+// CHECK-VEC2: %[[sub:.*]] = affine.min #[[map]](%[[s]], %[[j]])[%[[c16]]]
 // CHECK-VEC2: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
 // CHECK-VEC2: %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
 // CHECK-VEC2: %[[zj:.*]] = zexti %[[lj]] : vector<16xi32> to vector<16xi64>
@@ -333,7 +344,8 @@
 // CHECK-VEC2: }
 // CHECK-VEC2: return
 //
-// CHECK-VEC3-LABEL: func @mul_ds
+// CHECK-VEC3: #[[map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1)
+// CHECK-VEC3: func @mul_ds
 // CHECK-VEC3-DAG: %[[c0:.*]] = constant 0 : index
 // CHECK-VEC3-DAG: %[[c1:.*]] = constant 1 : index
 // CHECK-VEC3-DAG: %[[c16:.*]] = constant 16 : index
@@ -347,7 +359,7 @@
 // CHECK-VEC3: %[[b:.*]] = zexti %[[r]] : i32 to i64
 // CHECK-VEC3: %[[s:.*]] = index_cast %[[b]] : i64 to index
 // CHECK-VEC3: scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c16]] {
-// CHECK-VEC3: %[[sub:.*]] = subi %[[s]], %[[j]] : index
+// CHECK-VEC3: %[[sub:.*]] = affine.min #[[map]](%[[s]], %[[j]])[%[[c16]]]
 // CHECK-VEC3: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
 // CHECK-VEC3: %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
 // CHECK-VEC3: %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir
@@ -0,0 +1,64 @@
+// RUN: mlir-opt %s -sparsification="vectorization-strategy=2 vl=16" -for-loop-peeling -canonicalize | \
+// RUN: FileCheck %s
+
+#SparseVector = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed" ],
+  pointerBitWidth = 32,
+  indexBitWidth = 32
+}>
+
+#trait_mul_s = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>,  // a
+    affine_map<(i) -> (i)>,  // b
+    affine_map<(i) -> (i)>   // x (out)
+  ],
+  iterator_types = ["parallel"],
+  doc = "x(i) = a(i) * b(i)"
+}
+
+// CHECK-DAG: #[[map0:.*]] = affine_map<()[s0, s1] -> (s0 + ((-s0 + s1) floordiv 16) * 16)>
+// CHECK-DAG: #[[map1:.*]] = affine_map<()[s0, s1] -> ((-s0 + s1) mod 16)>
+// CHECK: func @mul_s
+// CHECK-DAG: %[[c0:.*]] = constant 0 : index
+// CHECK-DAG: %[[c1:.*]] = constant 1 : index
+// CHECK-DAG: %[[c16:.*]] = constant 16 : index
+// CHECK: %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
+// CHECK: %[[a:.*]] = zexti %[[p]] : i32 to i64
+// CHECK: %[[q:.*]] = index_cast %[[a]] : i64 to index
+// CHECK: %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xi32>
+// CHECK: %[[b:.*]] = zexti %[[r]] : i32 to i64
+// CHECK: %[[s:.*]] = index_cast %[[b]] : i64 to index
+// CHECK: %[[boundary:.*]] = affine.apply #[[map0]]()[%[[q]], %[[s]]]
+// CHECK: scf.for %[[i:.*]] = %[[q]] to %[[boundary]] step %[[c16]] {
+// CHECK: %[[mask:.*]] = vector.constant_mask [16] : vector<16xi1>
+// CHECK: %[[li:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xi32>, vector<16xi32>
+// CHECK: %[[zi:.*]] = zexti %[[li]] : vector<16xi32> to vector<16xi64>
+// CHECK: %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
+// CHECK: %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[zi]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK: %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
+// CHECK: vector.scatter %{{.*}}[%[[c0]]] [%[[zi]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32>
+// CHECK: }
+// CHECK: %[[has_more:.*]] = cmpi slt, %[[boundary]], %[[s]] : index
+// CHECK: scf.if %[[has_more]] {
+// CHECK: %[[sub:.*]] = affine.apply #[[map1]]()[%[[q]], %[[s]]]
+// CHECK: %[[mask2:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
+// CHECK: %[[li2:.*]] = vector.maskedload %{{.*}}[%[[boundary]]], %[[mask2]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
+// CHECK: %[[zi2:.*]] = zexti %[[li2]] : vector<16xi32> to vector<16xi64>
+// CHECK: %[[la2:.*]] = vector.maskedload %{{.*}}[%[[boundary]]], %[[mask2]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK: %[[lb2:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[zi2]]], %[[mask2]], %{{.*}} : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK: %[[m2:.*]] = mulf %[[la2]], %[[lb2]] : vector<16xf32>
+// CHECK: vector.scatter %{{.*}}[%[[c0]]] [%[[zi2]]], %[[mask2]], %[[m2]] : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32>
+// CHECK: }
+// CHECK: return
+//
+func @mul_s(%arga: tensor<1024xf32, #SparseVector>, %argb: tensor<1024xf32>, %argx: tensor<1024xf32>) -> tensor<1024xf32> {
+  %0 = linalg.generic #trait_mul_s
+    ins(%arga, %argb: tensor<1024xf32, #SparseVector>, tensor<1024xf32>)
+    outs(%argx: tensor<1024xf32>) {
+    ^bb(%a: f32, %b: f32, %x: f32):
+      %0 = mulf %a, %b : f32
+      linalg.yield %0 : f32
+  } -> tensor<1024xf32>
+  return %0 : tensor<1024xf32>
+}
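
Illustrative note (not part of the patch): the affine.min map added in Sparsification.cpp evaluates min(s0, d0 - d1) over (d0, d1)[s0] = (hi, iv)[step], i.e. the number of iterations left in the loop clamped to the vector length, and the result feeds vector.create_mask. A minimal standalone sketch of that pattern, assuming vl=16 and using placeholder SSA names (%hi, %iv, %c16):

  #vl = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>
  // active lanes = min(vector length, hi - iv);
  // e.g. hi = 1000, iv = 992, step = 16 leaves 8 active lanes.
  %n    = affine.min #vl(%hi, %iv)[%c16]
  %mask = vector.create_mask %n : vector<16xi1>

In the tests above the constant vector length 16 appears folded into the first result of the map.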