diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -81,6 +81,30 @@ llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType"); } +/// Generates a constant zero of the given type. +inline static Value constantZero(ConversionPatternRewriter &rewriter, + Location loc, Type t) { + return rewriter.create(loc, t, rewriter.getZeroAttr(t)); +} + +/// Generates a constant of `index` type. Takes `int64_t` to avoid truncation. +inline static Value constantIndex(ConversionPatternRewriter &rewriter, + Location loc, int64_t i) { + return rewriter.create(loc, i); +} + +/// Generates a constant of `i64` type. +inline static Value constantI64(ConversionPatternRewriter &rewriter, + Location loc, int64_t i) { + return rewriter.create(loc, i, 64); +} + +/// Generates a constant of `i32` type. +inline static Value constantI32(ConversionPatternRewriter &rewriter, + Location loc, int32_t i) { + return rewriter.create(loc, i, 32); +} + /// Returns integers of given width and values as a constant tensor. /// We cast the static shape into a dynamic shape to ensure that the /// method signature remains uniform across different tensor dimensions. @@ -161,18 +185,14 @@ unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth()); unsigned primary = getPrimaryTypeEncoding(resType.getElementType()); assert(primary); - params.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(secPtr))); - params.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(secInd))); - params.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(primary))); + params.push_back(constantI64(rewriter, loc, secPtr)); + params.push_back(constantI64(rewriter, loc, secInd)); + params.push_back(constantI64(rewriter, loc, primary)); // User action and pointer.
Type pTp = LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8)); if (!ptr) ptr = rewriter.create(loc, pTp); - params.push_back(rewriter.create( - loc, rewriter.getI32IntegerAttr(action))); + params.push_back(constantI32(rewriter, loc, action)); params.push_back(ptr); // Generate the call to create new tensor. StringRef name = "newSparseTensor"; @@ -182,19 +202,13 @@ return call.getResult(0); } -/// Generates a constant zero of the given type. -static Value getZero(ConversionPatternRewriter &rewriter, Location loc, - Type t) { - return rewriter.create(loc, rewriter.getZeroAttr(t)); -} - /// Generates the comparison `v != 0` where `v` is of numeric type `t`. /// For floating types, we use the "unordered" comparator (i.e., returns /// true if `v` is NaN). static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc, Value v) { Type t = v.getType(); - Value zero = getZero(rewriter, loc, t); + Value zero = constantZero(rewriter, loc, t); if (t.isa()) return rewriter.create(loc, arith::CmpFPredicate::UNE, v, zero); @@ -221,8 +235,7 @@ rewriter.setInsertionPointToStart(&ifOp.thenRegion().front()); unsigned i = 0; for (auto iv : ivs) { - Value idx = - rewriter.create(loc, rewriter.getIndexAttr(i++)); + Value idx = constantIndex(rewriter, loc, i++); rewriter.create(loc, iv, ind, idx); } return val; @@ -289,8 +302,7 @@ unsigned rank) { Location loc = op->getLoc(); for (unsigned i = 0; i < rank; i++) { - Value idx = - rewriter.create(loc, rewriter.getIndexAttr(i)); + Value idx = constantIndex(rewriter, loc, i); Value val = rewriter.create(loc, indices, ValueRange{ivs[0], idx}); val = @@ -308,8 +320,7 @@ int64_t rank) { auto indexTp = rewriter.getIndexType(); auto memTp = MemRefType::get({ShapedType::kDynamicSize}, indexTp); - Value arg = - rewriter.create(loc, rewriter.getIndexAttr(rank)); + Value arg = constantIndex(rewriter, loc, rank); return rewriter.create(loc, memTp, ValueRange{arg}); } @@ -352,8 +363,7 @@ StringRef name = "sparseDimSize";
SmallVector params; params.push_back(adaptor.getOperands()[0]); - params.push_back(rewriter.create( - op.getLoc(), rewriter.getIndexAttr(idx))); + params.push_back(constantIndex(rewriter, op.getLoc(), idx)); rewriter.replaceOpWithNewOp( op, resType, getFunc(op, name, resType, params), params); return success(); @@ -437,10 +447,8 @@ SmallVector lo; SmallVector hi; SmallVector st; - Value zero = - rewriter.create(loc, rewriter.getIndexAttr(0)); - Value one = - rewriter.create(loc, rewriter.getIndexAttr(1)); + Value zero = constantIndex(rewriter, loc, 0); + Value one = constantIndex(rewriter, loc, 1); auto indicesValues = genSplitSparseConstant(rewriter, op, src); bool isCOOConstant = indicesValues.hasValue(); Value indices;