diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -234,10 +234,28 @@
 static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
                                 VL vl, ValueRange subs, bool codegen,
                                 Value vmask, SmallVectorImpl<Value> &idxs) {
+  unsigned d = 0;
+  unsigned dim = subs.size();
   for (auto sub : subs) {
-    // Invariant/loop indices simply pass through.
-    if (sub.dyn_cast<BlockArgument>() ||
+    bool innermost = ++d == dim;
+    // Invariant subscripts in outer dimensions simply pass through.
+    // Note that we rely on LICM to hoist loads where all subscripts
+    // are invariant in the innermost loop.
+    if (sub.getDefiningOp() &&
         sub.getDefiningOp()->getBlock() != &forOp.getRegion().front()) {
+      if (innermost)
+        return false;
+      if (codegen)
+        idxs.push_back(sub);
+      continue; // success so far
+    }
+    // Invariant block arguments (including outer loop indices) in outer
+    // dimensions simply pass through. Direct loop indices in the
+    // innermost loop simply pass through as well.
+    if (auto barg = sub.dyn_cast<BlockArgument>()) {
+      bool invariant = barg.getOwner() != &forOp.getRegion().front();
+      if (invariant == innermost)
+        return false;
       if (codegen)
         idxs.push_back(sub);
       continue; // success so far
@@ -264,6 +282,8 @@
     // which creates the potential of incorrect address calculations in the
     // unlikely case we need such extremely large offsets.
     if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
+      if (!innermost)
+        return false;
       if (codegen) {
         SmallVector<Value> idxs2(load.getIndices()); // no need to analyze
         Location loc = forOp.getLoc();
@@ -286,9 +306,11 @@
     if (auto load = cast.getDefiningOp<arith::AddIOp>()) {
       Value inv = load.getOperand(0);
       Value idx = load.getOperand(1);
-      if (!inv.dyn_cast<BlockArgument>() &&
+      if (inv.getDefiningOp() &&
           inv.getDefiningOp()->getBlock() != &forOp.getRegion().front() &&
           idx.dyn_cast<BlockArgument>()) {
+        if (!innermost)
+          return false;
         if (codegen)
           idxs.push_back(
               rewriter.create<arith::AddIOp>(forOp.getLoc(), inv, idx));
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_concat.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_concat.mlir
new file
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_concat.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt %s --sparse-compiler="enable-runtime-library=false vl=2 reassociate-fp-reductions=true enable-index-optimizations=true"
+
+#MAT_D_C = #sparse_tensor.encoding<{
+  dimLevelType = ["dense", "compressed"]
+}>
+
+#MAT_C_C_P = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "compressed" ],
+  dimOrdering = affine_map<(i,j) -> (j,i)>
+}>
+
+#MAT_C_D_P = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "dense" ],
+  dimOrdering = affine_map<(i,j) -> (j,i)>
+}>
+
+//
+// Ensures only last loop is vectorized
+// (vectorizing the others would crash).
+//
+// CHECK-LABEL: llvm.func @foo
+// CHECK: llvm.intr.masked.load
+// CHECK: llvm.intr.masked.scatter
+//
+func.func @foo(%arg0: tensor<2x4xf64, #MAT_C_C_P>,
+               %arg1: tensor<3x4xf64, #MAT_C_D_P>,
+               %arg2: tensor<4x4xf64, #MAT_D_C>) -> tensor<9x4xf64> {
+  %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index}
+       : tensor<2x4xf64, #MAT_C_C_P>, tensor<3x4xf64, #MAT_C_D_P>, tensor<4x4xf64, #MAT_D_C> to tensor<9x4xf64>
+  return %0 : tensor<9x4xf64>
+}