diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -53,6 +53,7 @@
   }];
   let constructor = "mlir::createSparsificationPass()";
   let dependentDialects = [
+    "AffineDialect",
     "LLVM::LLVMDialect",
     "memref::MemRefDialect",
     "scf::SCFDialect",
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -10,10 +10,12 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Transforms.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
@@ -354,7 +356,13 @@
   // during vector execution. Here we rely on subsequent loop optimizations to
   // avoid executing the mask in all iterations, for example, by splitting the
   // loop into an unconditional vector loop and a scalar cleanup loop.
-  Value end = rewriter.create<SubIOp>(loc, hi, iv);
+  auto minMap = AffineMap::get(
+      /*dimCount=*/2, /*symbolCount=*/1,
+      {rewriter.getAffineSymbolExpr(0),
+       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
+      rewriter.getContext());
+  Value end =
+      rewriter.createOrFold<AffineMinOp>(loc, minMap, ValueRange{hi, iv, step});
   return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
 }
 
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
@@ -6,6 +6,8 @@
 // RUN:   FileCheck %s --check-prefix=CHECK-VEC2
 // RUN: mlir-opt %s -sparsification="vectorization-strategy=2 vl=16 enable-simd-index32=true" | \
 // RUN:   FileCheck %s --check-prefix=CHECK-VEC3
+// RUN: mlir-opt %s -sparsification="vectorization-strategy=2 vl=16" -for-loop-peeling -canonicalize | \
+// RUN:   FileCheck %s --check-prefix=CHECK-VEC2-PEELED
 
 #DenseVector = #sparse_tensor.encoding<{ dimLevelType = [ "dense" ] }>
 
@@ -139,7 +141,7 @@
 // CHECK-VEC2: %[[b:.*]] = zexti %[[r]] : i32 to i64
 // CHECK-VEC2: %[[s:.*]] = index_cast %[[b]] : i64 to index
 // CHECK-VEC2: scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] {
-// CHECK-VEC2: %[[sub:.*]] = subi %[[s]], %[[i]] : index
+// CHECK-VEC2: %[[sub:.*]] = affine.min #{{.*}}(%[[s]], %[[i]])[%[[c16]]]
 // CHECK-VEC2: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
 // CHECK-VEC2: %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
 // CHECK-VEC2: %[[zi:.*]] = zexti %[[li]] : vector<16xi32> to vector<16xi64>
@@ -150,6 +152,39 @@
 // CHECK-VEC2: }
 // CHECK-VEC2: return
 //
+// CHECK-VEC2-PEELED-LABEL: func @mul_s
+// CHECK-VEC2-PEELED-DAG: %[[c0:.*]] = constant 0 : index
+// CHECK-VEC2-PEELED-DAG: %[[c1:.*]] = constant 1 : index
+// CHECK-VEC2-PEELED-DAG: %[[c16:.*]] = constant 16 : index
+// CHECK-VEC2-PEELED: %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
+// CHECK-VEC2-PEELED: %[[a:.*]] = zexti %[[p]] : i32 to i64
+// CHECK-VEC2-PEELED: %[[q:.*]] = index_cast %[[a]] : i64 to index
+// CHECK-VEC2-PEELED: %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xi32>
+// CHECK-VEC2-PEELED: %[[b:.*]] = zexti %[[r]] : i32 to i64
+// CHECK-VEC2-PEELED: %[[s:.*]] = index_cast %[[b]] : i64 to index
+// CHECK-VEC2-PEELED: %[[boundary:.*]] = affine.apply #{{.*}}[%[[q]], %[[s]]]
+// CHECK-VEC2-PEELED: scf.for %[[i:.*]] = %[[q]] to %[[boundary]] step %[[c16]] {
+// CHECK-VEC2-PEELED: %[[mask:.*]] = vector.constant_mask [16] : vector<16xi1>
+// CHECK-VEC2-PEELED: %[[li:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xi32>, vector<16xi32>
+// CHECK-VEC2-PEELED: %[[zi:.*]] = zexti %[[li]] : vector<16xi32> to vector<16xi64>
+// CHECK-VEC2-PEELED: %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
+// CHECK-VEC2-PEELED: %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[zi]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC2-PEELED: %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
+// CHECK-VEC2-PEELED: vector.scatter %{{.*}}[%[[c0]]] [%[[zi]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32>
+// CHECK-VEC2-PEELED: }
+// CHECK-VEC2-PEELED: %[[has_more:.*]] = cmpi slt, %[[boundary]], %[[s]] : index
+// CHECK-VEC2-PEELED: scf.if %[[has_more]] {
+// CHECK-VEC2-PEELED: %[[sub:.*]] = affine.apply #{{.*}}[%[[q]], %[[s]]]
+// CHECK-VEC2-PEELED: %[[mask2:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
+// CHECK-VEC2-PEELED: %[[li2:.*]] = vector.maskedload %{{.*}}[%[[boundary]]], %[[mask2]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
+// CHECK-VEC2-PEELED: %[[zi2:.*]] = zexti %[[li2]] : vector<16xi32> to vector<16xi64>
+// CHECK-VEC2-PEELED: %[[la2:.*]] = vector.maskedload %{{.*}}[%[[boundary]]], %[[mask2]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC2-PEELED: %[[lb2:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[zi2]]], %[[mask2]], %{{.*}} : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC2-PEELED: %[[m2:.*]] = mulf %[[la2]], %[[lb2]] : vector<16xf32>
+// CHECK-VEC2-PEELED: vector.scatter %{{.*}}[%[[c0]]] [%[[zi2]]], %[[mask2]], %[[m2]] : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32>
+// CHECK-VEC2-PEELED: }
+// CHECK-VEC2-PEELED: return
+//
 // CHECK-VEC3-LABEL: func @mul_s
 // CHECK-VEC3-DAG: %[[c0:.*]] = constant 0 : index
 // CHECK-VEC3-DAG: %[[c1:.*]] = constant 1 : index
@@ -161,7 +196,7 @@
 // CHECK-VEC3: %[[b:.*]] = zexti %[[r]] : i32 to i64
 // CHECK-VEC3: %[[s:.*]] = index_cast %[[b]] : i64 to index
 // CHECK-VEC3: scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] {
-// CHECK-VEC3: %[[sub:.*]] = subi %{{.*}}, %[[i]] : index
+// CHECK-VEC3: %[[sub:.*]] = affine.min #{{.*}}(%[[s]], %[[i]])[%[[c16]]]
 // CHECK-VEC3: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
 // CHECK-VEC3: %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
 // CHECK-VEC3: %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
@@ -321,7 +356,7 @@
 // CHECK-VEC2: %[[b:.*]] = zexti %[[r]] : i32 to i64
 // CHECK-VEC2: %[[s:.*]] = index_cast %[[b]] : i64 to index
 // CHECK-VEC2: scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c16]] {
-// CHECK-VEC2: %[[sub:.*]] = subi %[[s]], %[[j]] : index
+// CHECK-VEC2: %[[sub:.*]] = affine.min #{{.*}}(%[[s]], %[[j]])[%[[c16]]]
 // CHECK-VEC2: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
 // CHECK-VEC2: %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
 // CHECK-VEC2: %[[zj:.*]] = zexti %[[lj]] : vector<16xi32> to vector<16xi64>
@@ -347,7 +382,7 @@
 // CHECK-VEC3: %[[b:.*]] = zexti %[[r]] : i32 to i64
 // CHECK-VEC3: %[[s:.*]] = index_cast %[[b]] : i64 to index
 // CHECK-VEC3: scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c16]] {
-// CHECK-VEC3: %[[sub:.*]] = subi %[[s]], %[[j]] : index
+// CHECK-VEC3: %[[sub:.*]] = affine.min #{{.*}}(%[[s]], %[[j]])[%[[c16]]]
 // CHECK-VEC3: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
 // CHECK-VEC3: %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
 // CHECK-VEC3: %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
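
Note on the change: the new mask bound is min(step, hi - iv), built as an affine.min over the map (d0, d1)[s0] -> (s0, d0 - d1) applied to (hi, iv)[step]. Because the bound is now an affine expression of the loop bounds and step, -for-loop-peeling followed by -canonicalize can prove the mask is all-true in the main loop and confine masking to a single remainder iteration, which is what the CHECK-VEC2-PEELED prefix above verifies. The sketch below is an illustration only, not part of the patch: the map names (#msk, #main, #rest) and the exact forms of #main and #rest are assumptions (the test only matches them with #{{.*}}), and the SSA names follow the CHECK captures.

// Assumed map shapes (illustrative only):
#msk  = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>              // min(step, ub - iv)
#main = affine_map<()[s0, s1] -> (s1 - ((s1 - s0) mod 16))>    // assumed peeled bound
#rest = affine_map<()[s0, s1] -> ((s1 - s0) mod 16)>           // assumed remainder size

// Before peeling (CHECK-VEC2): the mask bound is recomputed every iteration.
scf.for %i = %q to %s step %c16 {
  %n = affine.min #msk(%s, %i)[%c16]
  %mask = vector.create_mask %n : vector<16xi1>
  // ... masked loads, gather, mulf, scatter ...
}

// After -for-loop-peeling -canonicalize (CHECK-VEC2-PEELED): a full-width main
// loop with a constant all-true mask, plus an scf.if remainder that keeps the mask.
%boundary = affine.apply #main()[%q, %s]
scf.for %i = %q to %boundary step %c16 {
  %mask = vector.constant_mask [16] : vector<16xi1>
  // ... unmasked vector.load, gather, mulf, scatter ...
}
%has_more = cmpi slt, %boundary, %s : index
scf.if %has_more {
  %n = affine.apply #rest()[%q, %s]
  %mask = vector.create_mask %n : vector<16xi1>
  // ... masked loads, gather, mulf, scatter ...
}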