diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h @@ -76,9 +76,11 @@ /// for the corresponding `SmallVector` object. using LatSetId = unsigned; +namespace detail { /// A constant serving as the canonically invalid identifier, regardless /// of the identifier type. -static constexpr unsigned kInvalidId = -1u; +inline constexpr unsigned kInvalidId = -1u; +} // namespace detail /// Tensor expression. Represents an MLIR expression in tensor index notation. struct TensorExp final { @@ -269,13 +271,13 @@ /// Constructs a new tensor expression, and returns its identifier. /// The type of the `e0` argument varies according to the value of the /// `k` argument, as described by the `TensorExp` ctor. - ExprId addExp(TensorExp::Kind k, unsigned e0, ExprId e1 = kInvalidId, + ExprId addExp(TensorExp::Kind k, unsigned e0, ExprId e1 = detail::kInvalidId, Value v = Value(), Operation *op = nullptr); ExprId addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op = nullptr) { - return addExp(k, e, kInvalidId, v, op); + return addExp(k, e, detail::kInvalidId, v, op); } ExprId addExp(TensorExp::Kind k, Value v, Operation *op = nullptr) { - return addExp(k, kInvalidId, kInvalidId, v, op); + return addExp(k, detail::kInvalidId, detail::kInvalidId, v, op); } /// Constructs a new iteration lattice point, and returns its identifier. 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h @@ -133,7 +133,7 @@ // void startReduc(ExprId exp, Value val); - bool isReduc() const { return redExp != kInvalidId; } + bool isReduc() const { return redExp != detail::kInvalidId; } void updateReduc(Value val); Value getReduc() const { return redVal; } Value endReduc(); @@ -142,7 +142,7 @@ Value getValidLexInsert() const { return redValidLexInsert; } void startCustomReduc(ExprId exp); - bool isCustomReduc() const { return redCustom != kInvalidId; } + bool isCustomReduc() const { return redCustom != detail::kInvalidId; } Value getCustomRedId(); void endCustomReduc(); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp @@ -55,8 +55,8 @@ : linalgOp(linop), sparseOptions(opts), latticeMerger(numTensors, numLoops, numFilterLoops), loopEmitter(), topSort(), sparseOut(nullptr), outerParNest(-1u), insChain(), expValues(), - expFilled(), expAdded(), expCount(), redVal(), redExp(kInvalidId), - redCustom(kInvalidId), redValidLexInsert() {} + expFilled(), expAdded(), expCount(), redVal(), redExp(detail::kInvalidId), + redCustom(detail::kInvalidId), redValidLexInsert() {} LogicalResult CodegenEnv::initTensorExp() { // Builds the tensor expression for the Linalg operation in SSA form. 
@@ -277,7 +277,7 @@ //===----------------------------------------------------------------------===// void CodegenEnv::startReduc(ExprId exp, Value val) { - assert(!isReduc() && exp != kInvalidId); + assert(!isReduc() && exp != detail::kInvalidId); redExp = exp; updateReduc(val); } @@ -296,7 +296,7 @@ Value val = redVal; redVal = val; latticeMerger.clearExprValue(redExp); - redExp = kInvalidId; + redExp = detail::kInvalidId; return val; } @@ -311,7 +311,7 @@ } void CodegenEnv::startCustomReduc(ExprId exp) { - assert(!isCustomReduc() && exp != kInvalidId); + assert(!isCustomReduc() && exp != detail::kInvalidId); redCustom = exp; } @@ -322,5 +322,5 @@ void CodegenEnv::endCustomReduc() { assert(isCustomReduc()); - redCustom = kInvalidId; + redCustom = detail::kInvalidId; } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -1111,7 +1111,7 @@ linalg::GenericOp op = env.op(); Location loc = op.getLoc(); - if (e == kInvalidId) + if (e == ::mlir::sparse_tensor::detail::kInvalidId) return Value(); const TensorExp &exp = env.exp(e); const auto kind = exp.kind; @@ -1146,11 +1146,11 @@ /// Hoists loop invariant tensor loads for which indices have been exhausted. static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp, LoopId ldx, bool atStart) { - if (exp == kInvalidId) + if (exp == ::mlir::sparse_tensor::detail::kInvalidId) return; if (env.exp(exp).kind == TensorExp::Kind::kTensor) { // Inspect tensor indices. - bool isAtLoop = ldx == kInvalidId; + bool isAtLoop = ldx == ::mlir::sparse_tensor::detail::kInvalidId; linalg::GenericOp op = env.op(); OpOperand &t = op->getOpOperand(env.exp(exp).tensor); auto map = op.getMatchingIndexingMap(&t); @@ -1715,7 +1715,8 @@ // Construct iteration lattices for current loop index, with L0 at top. 
const LoopId idx = env.topSortAt(at); - const LoopId ldx = at == 0 ? kInvalidId : env.topSortAt(at - 1); + const LoopId ldx = at == 0 ? ::mlir::sparse_tensor::detail::kInvalidId + : env.topSortAt(at - 1); const LatSetId lts = env.merger().optimizeSet(env.merger().buildLattices(exp, idx)); diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -103,14 +103,14 @@ switch (kind) { // Leaf. case TensorExp::Kind::kTensor: - assert(x != kInvalidId && y == kInvalidId && !v && !o); + assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o); tensor = x; return; case TensorExp::Kind::kInvariant: - assert(x == kInvalidId && y == kInvalidId && v && !o); + assert(x == detail::kInvalidId && y == detail::kInvalidId && v && !o); return; case TensorExp::Kind::kLoopVar: - assert(x != kInvalidId && y == kInvalidId && !v && !o); + assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o); loop = x; return; // Unary operations. 
@@ -134,7 +134,7 @@ case TensorExp::Kind::kNegI: case TensorExp::Kind::kCIm: case TensorExp::Kind::kCRe: - assert(x != kInvalidId && y == kInvalidId && !v && !o); + assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o); children.e0 = x; children.e1 = y; return; @@ -149,20 +149,20 @@ case TensorExp::Kind::kCastIdx: case TensorExp::Kind::kTruncI: case TensorExp::Kind::kBitCast: - assert(x != kInvalidId && y == kInvalidId && v && !o); + assert(x != detail::kInvalidId && y == detail::kInvalidId && v && !o); children.e0 = x; children.e1 = y; return; case TensorExp::Kind::kBinaryBranch: case TensorExp::Kind::kSelect: - assert(x != kInvalidId && y == kInvalidId && !v && o); + assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && o); children.e0 = x; children.e1 = y; return; case TensorExp::Kind::kUnary: // No assertion on y can be made, as the branching paths involve both // a unary (`mapSet`) and binary (`disjSet`) pathway. - assert(x != kInvalidId && !v && o); + assert(x != detail::kInvalidId && !v && o); children.e0 = x; children.e1 = y; return; @@ -186,13 +186,13 @@ case TensorExp::Kind::kShrS: case TensorExp::Kind::kShrU: case TensorExp::Kind::kShlI: - assert(x != kInvalidId && y != kInvalidId && !v && !o); + assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o); children.e0 = x; children.e1 = y; return; case TensorExp::Kind::kBinary: case TensorExp::Kind::kReduce: - assert(x != kInvalidId && y != kInvalidId && !v && o); + assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && o); children.e0 = x; children.e1 = y; return;