diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -50,6 +50,16 @@
   return false;
 }
 
+/// Helper test for invariant value (defined outside given block).
+static bool isInvariantValue(Value val, Block *block) {
+  return val.getDefiningOp() && val.getDefiningOp()->getBlock() != block;
+}
+
+/// Helper test for invariant argument (defined outside given block).
+static bool isInvariantArg(BlockArgument arg, Block *block) {
+  return arg.getOwner() != block;
+}
+
 /// Constructs vector type for element type.
 static VectorType vectorType(VL vl, Type etp) {
   unsigned numScalableDims = vl.enableVLAVectorization;
@@ -236,13 +246,15 @@
                                 Value vmask, SmallVectorImpl<Value> &idxs) {
   unsigned d = 0;
   unsigned dim = subs.size();
+  Block *block = &forOp.getRegion().front();
   for (auto sub : subs) {
     bool innermost = ++d == dim;
     // Invariant subscripts in outer dimensions simply pass through.
     // Note that we rely on LICM to hoist loads where all subscripts
     // are invariant in the innermost loop.
-    if (sub.getDefiningOp() &&
-        sub.getDefiningOp()->getBlock() != &forOp.getRegion().front()) {
+    // Example:
+    //   a[inv][i] for inv
+    if (isInvariantValue(sub, block)) {
       if (innermost)
         return false;
       if (codegen)
@@ -252,9 +264,10 @@
     // Invariant block arguments (including outer loop indices) in outer
     // dimensions simply pass through. Direct loop indices in the
     // innermost loop simply pass through as well.
-    if (auto barg = sub.dyn_cast<BlockArgument>()) {
-      bool invariant = barg.getOwner() != &forOp.getRegion().front();
-      if (invariant == innermost)
+    // Example:
+    //   a[i][j] for both i and j
+    if (auto arg = sub.dyn_cast<BlockArgument>()) {
+      if (isInvariantArg(arg, block) == innermost)
         return false;
       if (codegen)
         idxs.push_back(sub);
@@ -281,6 +294,8 @@
     // values, there is no good way to state that the indices are unsigned,
     // which creates the potential of incorrect address calculations in the
     // unlikely case we need such extremely large offsets.
+    // Example:
+    //   a[ ind[i] ]
     if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
       if (!innermost)
         return false;
@@ -303,18 +318,20 @@
       continue; // success so far
     }
     // Address calculation 'i = add inv, idx' (after LICM).
+    // Example:
+    //   a[base + i]
    if (auto load = cast.getDefiningOp<arith::AddIOp>()) {
       Value inv = load.getOperand(0);
       Value idx = load.getOperand(1);
-      if (inv.getDefiningOp() &&
-          inv.getDefiningOp()->getBlock() != &forOp.getRegion().front() &&
-          idx.dyn_cast<BlockArgument>()) {
-        if (!innermost)
-          return false;
-        if (codegen)
-          idxs.push_back(
-              rewriter.create<arith::AddIOp>(forOp.getLoc(), inv, idx));
-        continue; // success so far
+      if (isInvariantValue(inv, block)) {
+        if (auto arg = idx.dyn_cast<BlockArgument>()) {
+          if (isInvariantArg(arg, block) || !innermost)
+            return false;
+          if (codegen)
+            idxs.push_back(
+                rewriter.create<arith::AddIOp>(forOp.getLoc(), inv, idx));
+          continue; // success so far
+        }
       }
     }
     return false;
@@ -389,7 +406,8 @@
   }
   // Something defined outside the loop-body is invariant.
   Operation *def = exp.getDefiningOp();
-  if (def->getBlock() != &forOp.getRegion().front()) {
+  Block *block = &forOp.getRegion().front();
+  if (def->getBlock() != block) {
     if (codegen)
       vexp = genVectorInvariantValue(rewriter, vl, exp);
     return true;
@@ -450,6 +468,17 @@
                       vx) &&
         vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask,
                       vy)) {
+      // We only accept shift-by-invariant (where the same shift factor
+      // applies to all packed elements). In the vector dialect, this is
+      // still represented with an expanded vector on the right-hand side,
+      // however, so that we do not have to special-case the code generation.
+      if (isa<arith::ShLIOp>(def) || isa<arith::ShRUIOp>(def) ||
+          isa<arith::ShRSIOp>(def)) {
+        Value shiftFactor = def->getOperand(1);
+        if (!isInvariantValue(shiftFactor, block))
+          return false;
+      }
+      // Generate code.
       BINOP(arith::MulFOp)
       BINOP(arith::MulIOp)
       BINOP(arith::DivFOp)
@@ -462,8 +491,10 @@
       BINOP(arith::AndIOp)
       BINOP(arith::OrIOp)
       BINOP(arith::XOrIOp)
+      BINOP(arith::ShLIOp)
+      BINOP(arith::ShRUIOp)
+      BINOP(arith::ShRSIOp)
       // TODO: complex?
-      // TODO: shift by invariant?
     }
   }
   return false;
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_ops.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_ops.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_ops.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_ops.mlir
@@ -17,6 +17,8 @@
 // CHECK-DAG: %[[C1:.*]] = arith.constant dense<2.000000e+00> : vector<8xf32>
 // CHECK-DAG: %[[C2:.*]] = arith.constant dense<1.000000e+00> : vector<8xf32>
 // CHECK-DAG: %[[C3:.*]] = arith.constant dense<255> : vector<8xi64>
+// CHECK-DAG: %[[C4:.*]] = arith.constant dense<4> : vector<8xi32>
+// CHECK-DAG: %[[C5:.*]] = arith.constant dense<1> : vector<8xi32>
 // CHECK: scf.for
 // CHECK:   %[[VAL_14:.*]] = vector.load
 // CHECK:   %[[VAL_15:.*]] = math.absf %[[VAL_14]] : vector<8xf32>
@@ -38,8 +40,11 @@
 // CHECK:   %[[VAL_31:.*]] = arith.andi %[[VAL_30]], %[[C3]] : vector<8xi64>
 // CHECK:   %[[VAL_32:.*]] = arith.trunci %[[VAL_31]] : vector<8xi64> to vector<8xi16>
 // CHECK:   %[[VAL_33:.*]] = arith.extsi %[[VAL_32]] : vector<8xi16> to vector<8xi32>
-// CHECK:   %[[VAL_34:.*]] = arith.uitofp %[[VAL_33]] : vector<8xi32> to vector<8xf32>
-// CHECK:   vector.store %[[VAL_34]]
+// CHECK:   %[[VAL_34:.*]] = arith.shrsi %[[VAL_33]], %[[C4]] : vector<8xi32>
+// CHECK:   %[[VAL_35:.*]] = arith.shrui %[[VAL_34]], %[[C4]] : vector<8xi32>
+// CHECK:   %[[VAL_36:.*]] = arith.shli %[[VAL_35]], %[[C5]] : vector<8xi32>
+// CHECK:   %[[VAL_37:.*]] = arith.uitofp %[[VAL_36]] : vector<8xi32> to vector<8xf32>
+// CHECK:   vector.store %[[VAL_37]]
 // CHECK: }
 func.func @vops(%arga: tensor<1024xf32, #DenseVector>,
                 %argb: tensor<1024xf32, #DenseVector>) -> tensor<1024xf32> {
@@ -47,6 +52,8 @@
   %o = arith.constant 1.0 : f32
   %c = arith.constant 2.0 : f32
   %i = arith.constant 255 : i64
+  %s = arith.constant 4 : i32
+  %t = arith.constant 1 : i32
   %0 = linalg.generic #trait
     ins(%arga, %argb: tensor<1024xf32, #DenseVector>, tensor<1024xf32, #DenseVector>)
     outs(%init: tensor<1024xf32>) {
@@ -69,8 +76,11 @@
       %15 = arith.andi %14, %i : i64
       %16 = arith.trunci %15 : i64 to i16
       %17 = arith.extsi %16 : i16 to i32
-      %18 = arith.uitofp %17 : i32 to f32
-      linalg.yield %18 : f32
+      %18 = arith.shrsi %17, %s : i32
+      %19 = arith.shrui %18, %s : i32
+      %20 = arith.shli %19, %t : i32
+      %21 = arith.uitofp %20 : i32 to f32
+      linalg.yield %21 : f32
     } -> tensor<1024xf32>
   return %0 : tensor<1024xf32>
 }
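
For context, a minimal sketch of the rewrite that the new shift-by-invariant
check enables (not part of the patch; %mem, the loop bounds, and the constants
below are illustrative):

  // Scalar input: the shift factor %s is defined outside the loop body,
  // so it is invariant and vectorization is accepted. A shift whose factor
  // varies per iteration is rejected instead.
  %s = arith.constant 4 : i32
  scf.for %i = %c0 to %c1024 step %c1 {
    %v = memref.load %mem[%i] : memref<1024xi32>
    %x = arith.shrsi %v, %s : i32
    memref.store %x, %mem[%i] : memref<1024xi32>
  }

  // Vectorized output (VL = 8): the invariant factor is expanded into a
  // splat constant on the right-hand side, so the same amount applies to
  // every packed lane and code generation needs no special case.
  %vs = arith.constant dense<4> : vector<8xi32>
  scf.for %i = %c0 to %c1024 step %c8 {
    %vv = vector.load %mem[%i] : memref<1024xi32>, vector<8xi32>
    %vx = arith.shrsi %vv, %vs : vector<8xi32>
    vector.store %vx, %mem[%i] : memref<1024xi32>, vector<8xi32>
  }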