diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/CodegenUtils.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/CodegenUtils.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/CodegenUtils.h @@ -0,0 +1,159 @@ +//===- CodegenUtils.h - Utilities for generating MLIR -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines utilities for generating MLIR. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SPARSETENSOR_UTILS_CODEGENUTILS_H_ +#define MLIR_DIALECT_SPARSETENSOR_UTILS_CODEGENUTILS_H_ + +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/ExecutionEngine/SparseTensorUtils.h" +#include "mlir/IR/Builders.h" + +namespace mlir { +// Forward references. +class Location; +class Type; +class Value; + +namespace sparse_tensor { + +//===----------------------------------------------------------------------===// +// ExecutionEngine/SparseTensorUtils helper functions. +//===----------------------------------------------------------------------===// + +/// Converts an overhead storage bitwidth to its internal type-encoding. +OverheadType overheadTypeEncoding(unsigned width); + +/// Converts the internal type-encoding for overhead storage to an mlir::Type. +Type getOverheadType(Builder &builder, OverheadType ot); + +/// Returns the mlir::Type for pointer overhead storage. +Type getPointerOverheadType(Builder &builder, + const SparseTensorEncodingAttr &enc); + +/// Returns the mlir::Type for index overhead storage. +Type getIndexOverheadType(Builder &builder, + const SparseTensorEncodingAttr &enc); + +/// Converts a primary storage type to its internal type-encoding. +PrimaryType primaryTypeEncoding(Type elemTp); + +/// Converts the IR's dimension level type to its internal type-encoding. +DimLevelType dimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt); + +//===----------------------------------------------------------------------===// +// Misc code generators. +//===----------------------------------------------------------------------===// + +/// Generates a 1-valued attribute of the given type. This supports +/// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`, +/// for unsupported types we raise `llvm_unreachable` rather than +/// returning a null attribute. +Attribute getOneAttr(Builder &builder, Type tp); + +/// Generates the comparison `v != 0` where `v` is of numeric type. +/// For floating types, we use the "unordered" comparator (i.e., returns +/// true if `v` is NaN). +Value genIsNonzero(OpBuilder &builder, Location loc, Value v); + +//===----------------------------------------------------------------------===// +// Constant generators. +// +// All these functions are just wrappers to improve code legibility; +// therefore, we mark them as `inline` to avoid introducing any additional +// overhead due to the legibility. +//===----------------------------------------------------------------------===// + +/// Generates a 0-valued constant of the given type. In addition to +/// the scalar types (`FloatType`, `IndexType`, `IntegerType`), this also +/// works for `RankedTensorType` and `VectorType` (for which it generates +/// a constant `DenseElementsAttr` of zeros). +inline Value constantZero(OpBuilder &builder, Location loc, Type tp) { + return builder.create(loc, tp, builder.getZeroAttr(tp)); +} + +/// Generates a 1-valued constant of the given type. This supports all +/// the same types as `constantZero`. +inline Value constantOne(OpBuilder &builder, Location loc, Type tp) { + return builder.create(loc, tp, getOneAttr(builder, tp)); +} + +/// Generates a constant of `index` type. +inline Value constantIndex(OpBuilder &builder, Location loc, int64_t i) { + return builder.create(loc, i); +} + +/// Generates a constant of `i32` type. +inline Value constantI32(OpBuilder &builder, Location loc, int32_t i) { + return builder.create(loc, i, 32); +} + +/// Generates a constant of `i16` type. +inline Value constantI16(OpBuilder &builder, Location loc, int16_t i) { + return builder.create(loc, i, 16); +} + +/// Generates a constant of `i8` type. +inline Value constantI8(OpBuilder &builder, Location loc, int8_t i) { + return builder.create(loc, i, 8); +} + +/// Generates a constant of `i1` type. +inline Value constantI1(OpBuilder &builder, Location loc, bool b) { + return builder.create(loc, b, 1); +} + +/// Generates a constant of the given `Action`. +inline Value constantAction(OpBuilder &builder, Location loc, Action action) { + return constantI32(builder, loc, static_cast(action)); +} + +/// Generates a constant of the internal type-encoding for overhead storage. +inline Value constantOverheadTypeEncoding(OpBuilder &builder, Location loc, + unsigned width) { + return constantI32(builder, loc, + static_cast(overheadTypeEncoding(width))); +} + +/// Generates a constant of the internal type-encoding for pointer +/// overhead storage. +inline Value constantPointerTypeEncoding(OpBuilder &builder, Location loc, + const SparseTensorEncodingAttr &enc) { + return constantOverheadTypeEncoding(builder, loc, enc.getPointerBitWidth()); +} + +/// Generates a constant of the internal type-encoding for index overhead +/// storage. +inline Value constantIndexTypeEncoding(OpBuilder &builder, Location loc, + const SparseTensorEncodingAttr &enc) { + return constantOverheadTypeEncoding(builder, loc, enc.getIndexBitWidth()); +} + +/// Generates a constant of the internal type-encoding for primary storage. +inline Value constantPrimaryTypeEncoding(OpBuilder &builder, Location loc, + Type elemTp) { + return constantI32(builder, loc, + static_cast(primaryTypeEncoding(elemTp))); +} + +/// Generates a constant of the internal dimension level type encoding. +inline Value +constantDimLevelTypeEncoding(OpBuilder &builder, Location loc, + SparseTensorEncodingAttr::DimLevelType dlt) { + return constantI8(builder, loc, + static_cast(dimLevelTypeEncoding(dlt))); +} + +} // namespace sparse_tensor +} // namespace mlir + +#endif // MLIR_DIALECT_SPARSETENSOR_UTILS_CODEGENUTILS_H_ diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -21,6 +21,7 @@ #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" +#include "mlir/Dialect/SparseTensor/Utils/CodegenUtils.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/ExecutionEngine/SparseTensorUtils.h" @@ -35,113 +36,6 @@ // Helper methods. //===----------------------------------------------------------------------===// -/// Generates a constant zero of the given type. -inline static Value constantZero(ConversionPatternRewriter &rewriter, - Location loc, Type t) { - return rewriter.create(loc, t, rewriter.getZeroAttr(t)); -} - -/// Generates a constant of `index` type. -inline static Value constantIndex(ConversionPatternRewriter &rewriter, - Location loc, int64_t i) { - return rewriter.create(loc, i); -} - -/// Generates a constant of `i32` type. -inline static Value constantI32(ConversionPatternRewriter &rewriter, - Location loc, int32_t i) { - return rewriter.create(loc, i, 32); -} - -/// Generates a constant of `i8` type. -inline static Value constantI8(ConversionPatternRewriter &rewriter, - Location loc, int8_t i) { - return rewriter.create(loc, i, 8); -} - -/// Generates a constant of the given `Action`. -static Value constantAction(ConversionPatternRewriter &rewriter, Location loc, - Action action) { - return constantI32(rewriter, loc, static_cast(action)); -} - -/// Generates a constant of the internal type encoding for overhead storage. -static Value constantOverheadTypeEncoding(ConversionPatternRewriter &rewriter, - Location loc, unsigned width) { - OverheadType sec; - switch (width) { - default: - sec = OverheadType::kU64; - break; - case 32: - sec = OverheadType::kU32; - break; - case 16: - sec = OverheadType::kU16; - break; - case 8: - sec = OverheadType::kU8; - break; - } - return constantI32(rewriter, loc, static_cast(sec)); -} - -/// Generates a constant of the internal type encoding for pointer -/// overhead storage. -static Value constantPointerTypeEncoding(ConversionPatternRewriter &rewriter, - Location loc, - SparseTensorEncodingAttr &enc) { - return constantOverheadTypeEncoding(rewriter, loc, enc.getPointerBitWidth()); -} - -/// Generates a constant of the internal type encoding for index overhead -/// storage. -static Value constantIndexTypeEncoding(ConversionPatternRewriter &rewriter, - Location loc, - SparseTensorEncodingAttr &enc) { - return constantOverheadTypeEncoding(rewriter, loc, enc.getIndexBitWidth()); -} - -/// Generates a constant of the internal type encoding for primary storage. -static Value constantPrimaryTypeEncoding(ConversionPatternRewriter &rewriter, - Location loc, Type tp) { - PrimaryType primary; - if (tp.isF64()) - primary = PrimaryType::kF64; - else if (tp.isF32()) - primary = PrimaryType::kF32; - else if (tp.isInteger(64)) - primary = PrimaryType::kI64; - else if (tp.isInteger(32)) - primary = PrimaryType::kI32; - else if (tp.isInteger(16)) - primary = PrimaryType::kI16; - else if (tp.isInteger(8)) - primary = PrimaryType::kI8; - else - llvm_unreachable("Unknown element type"); - return constantI32(rewriter, loc, static_cast(primary)); -} - -/// Generates a constant of the internal dimension level type encoding. -static Value -constantDimLevelTypeEncoding(ConversionPatternRewriter &rewriter, Location loc, - SparseTensorEncodingAttr::DimLevelType dlt) { - DimLevelType dlt2; - switch (dlt) { - case SparseTensorEncodingAttr::DimLevelType::Dense: - dlt2 = DimLevelType::kDense; - break; - case SparseTensorEncodingAttr::DimLevelType::Compressed: - dlt2 = DimLevelType::kCompressed; - break; - case SparseTensorEncodingAttr::DimLevelType::Singleton: - dlt2 = DimLevelType::kSingleton; - break; - } - return constantI8(rewriter, loc, static_cast(dlt2)); -} - /// Returns a function reference (first hit also inserts into module). Sets /// the "_emit_c_interface" on the function declaration when requested, /// so that LLVM lowering generates a wrapper function that takes care @@ -306,22 +200,6 @@ params.push_back(ptr); } -/// Generates the comparison `v != 0` where `v` is of numeric type `t`. -/// For floating types, we use the "unordered" comparator (i.e., returns -/// true if `v` is NaN). -static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc, - Value v) { - Type t = v.getType(); - Value zero = constantZero(rewriter, loc, t); - if (t.isa()) - return rewriter.create(loc, arith::CmpFPredicate::UNE, v, - zero); - if (t.isIntOrIndex()) - return rewriter.create(loc, arith::CmpIPredicate::ne, v, - zero); - llvm_unreachable("Unknown element type"); -} - /// Generates the code to read the value from tensor[ivs], and conditionally /// stores the indices ivs to the memory in ind. The generated code looks like /// the following and the insertion point after this routine is inside the diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -21,6 +21,7 @@ #include "mlir/Dialect/SCF/Transforms.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" +#include "mlir/Dialect/SparseTensor/Utils/CodegenUtils.h" #include "mlir/Dialect/SparseTensor/Utils/Merger.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Vector/VectorOps.h" @@ -400,26 +401,16 @@ case kNoReduc: break; case kSum: - case kXor: { + case kXor: // Initialize reduction vector to: | 0 | .. | 0 | r | - Attribute zero = rewriter.getZeroAttr(vtp); - Value vec = rewriter.create(loc, vtp, zero); return rewriter.create( - loc, r, vec, rewriter.create(loc, 0)); - } - case kProduct: { + loc, r, constantZero(rewriter, loc, vtp), + constantIndex(rewriter, loc, 0)); + case kProduct: // Initialize reduction vector to: | 1 | .. | 1 | r | - Type etp = vtp.getElementType(); - Attribute one; - if (etp.isa()) - one = rewriter.getFloatAttr(etp, 1.0); - else - one = rewriter.getIntegerAttr(etp, 1); - Value vec = rewriter.create( - loc, vtp, DenseElementsAttr::get(vtp, one)); return rewriter.create( - loc, r, vec, rewriter.create(loc, 0)); - } + loc, r, constantOne(rewriter, loc, vtp), + constantIndex(rewriter, loc, 0)); case kAnd: case kOr: // Initialize reduction vector to: | r | .. | r | r | @@ -447,13 +438,6 @@ // Sparse compiler synthesis methods (statements and expressions). //===----------------------------------------------------------------------===// -/// Maps sparse integer option to actual integral storage type. -static Type genIntType(PatternRewriter &rewriter, unsigned width) { - if (width == 0) - return rewriter.getIndexType(); - return rewriter.getIntegerType(width); -} - /// Generates buffer for the output tensor. Note that all sparse kernels /// assume that when all elements are written to (viz. x(i) = y(i) * z(i)), /// the output buffer is already initialized to all zeroes and only nonzeroes @@ -478,10 +462,8 @@ // materializes into the computation, we need to preserve the zero // initialization assumption of all sparse output buffers. if (isMaterializing(tensor)) { - Type tp = denseTp.getElementType(); Value alloc = rewriter.create(loc, denseTp, args); - Value zero = - rewriter.create(loc, tp, rewriter.getZeroAttr(tp)); + Value zero = constantZero(rewriter, loc, denseTp.getElementType()); rewriter.create(loc, zero, alloc); return alloc; } @@ -516,11 +498,11 @@ // Handle sparse storage schemes. if (merger.isDim(tensor, idx, Dim::kSparse)) { auto dynShape = {ShapedType::kDynamicSize}; - auto ptrTp = MemRefType::get( - dynShape, genIntType(rewriter, enc.getPointerBitWidth())); - auto indTp = MemRefType::get( - dynShape, genIntType(rewriter, enc.getIndexBitWidth())); - Value dim = rewriter.create(loc, d); + auto ptrTp = + MemRefType::get(dynShape, getPointerOverheadType(rewriter, enc)); + auto indTp = + MemRefType::get(dynShape, getIndexOverheadType(rewriter, enc)); + Value dim = constantIndex(rewriter, loc, d); // Generate sparse primitives to obtains pointer and indices. codegen.pointers[tensor][idx] = rewriter.create(loc, ptrTp, t->get(), dim); @@ -551,7 +533,7 @@ genOutputBuffer(codegen, rewriter, op, denseTp, args); } else if (t == codegen.sparseOut) { // True sparse output needs a lexIdx array. - Value rank = rewriter.create(loc, op.getRank(t)); + Value rank = constantIndex(rewriter, loc, op.getRank(t)); auto dynShape = {ShapedType::kDynamicSize}; auto memTp = MemRefType::get(dynShape, rewriter.getIndexType()); codegen.lexIdx = rewriter.create(loc, memTp, rank); @@ -579,7 +561,7 @@ static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter, Value iv, Value lo, Value hi, Value step) { Location loc = iv.getLoc(); - VectorType mtp = vectorType(codegen, genIntType(rewriter, 1)); + VectorType mtp = vectorType(codegen, rewriter.getI1Type()); // Special case if the vector length evenly divides the trip count (for // example, "for i = 0, 128, 16"). A constant all-true mask is generated // so that all subsequent masked memory operations are immediately folded @@ -590,7 +572,7 @@ matchPattern(step, m_Constant(&stepInt))) { if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) return rewriter.create( - loc, mtp, rewriter.create(loc, 1, 1)); + loc, mtp, constantI1(rewriter, loc, true)); } // Otherwise, generate a vector mask that avoids overrunning the upperbound // during vector execution. Here we rely on subsequent loop optimizations to @@ -611,12 +593,11 @@ Value ptr, ArrayRef args) { Location loc = ptr.getLoc(); VectorType vtp = vectorType(codegen, ptr); - Value pass = - rewriter.create(loc, vtp, rewriter.getZeroAttr(vtp)); + Value pass = constantZero(rewriter, loc, vtp); if (args.back().getType().isa()) { SmallVector scalarArgs(args.begin(), args.end()); Value indexVec = args.back(); - scalarArgs.back() = rewriter.create(loc, 0); + scalarArgs.back() = constantIndex(rewriter, loc, 0); return rewriter.create( loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass); } @@ -631,7 +612,7 @@ if (args.back().getType().isa()) { SmallVector scalarArgs(args.begin(), args.end()); Value indexVec = args.back(); - scalarArgs.back() = rewriter.create(loc, 0); + scalarArgs.back() = constantIndex(rewriter, loc, 0); rewriter.create(loc, ptr, scalarArgs, indexVec, codegen.curVecMask, rhs); return; @@ -673,7 +654,7 @@ } case AffineExprKind::Constant: { int64_t c = a.cast().getValue(); - return rewriter.create(loc, c); + return constantIndex(rewriter, loc, c); } default: llvm_unreachable("unexpected affine subscript"); @@ -720,8 +701,7 @@ OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor]; if (t == codegen.sparseOut) { Type tp = getElementTypeOrSelf(t->get().getType()); - return rewriter.create(op.getLoc(), tp, - rewriter.getZeroAttr(tp)); + return constantZero(rewriter, op.getLoc(), tp); } // Actual load. SmallVector args; @@ -783,11 +763,11 @@ if (!etp.isa()) { if (etp.getIntOrFloatBitWidth() < 32) vload = rewriter.create( - loc, vload, vectorType(codegen, genIntType(rewriter, 32))); + loc, vload, vectorType(codegen, rewriter.getI32Type())); else if (etp.getIntOrFloatBitWidth() < 64 && !codegen.options.enableSIMDIndex32) vload = rewriter.create( - loc, vload, vectorType(codegen, genIntType(rewriter, 64))); + loc, vload, vectorType(codegen, rewriter.getI64Type())); } return vload; } @@ -798,8 +778,7 @@ Value load = rewriter.create(loc, ptr, s); if (!load.getType().isa()) { if (load.getType().getIntOrFloatBitWidth() < 64) - load = - rewriter.create(loc, load, genIntType(rewriter, 64)); + load = rewriter.create(loc, load, rewriter.getI64Type()); load = rewriter.create(loc, load, rewriter.getIndexType()); } @@ -939,8 +918,8 @@ break; } Value ptr = codegen.pointers[tensor][idx]; - Value one = rewriter.create(loc, 1); - Value p0 = (pat == 0) ? rewriter.create(loc, 0) + Value one = constantIndex(rewriter, loc, 1); + Value p0 = (pat == 0) ? constantIndex(rewriter, loc, 0) : codegen.pidxs[tensor][topSort[pat - 1]]; codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0); Value p1 = rewriter.create(loc, p0, one); @@ -953,7 +932,7 @@ } // Initialize the universal dense index. - codegen.loops[idx] = rewriter.create(loc, 0); + codegen.loops[idx] = constantIndex(rewriter, loc, 0); return needsUniv; } @@ -1043,8 +1022,7 @@ Location loc = op.getLoc(); Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx]; Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx]; - Value step = - rewriter.create(loc, codegen.curVecLength); + Value step = constantIndex(rewriter, loc, codegen.curVecLength); // Emit a parallel loop. if (isParallel) { @@ -1208,7 +1186,7 @@ for (; pat != 0; pat--) if (codegen.pidxs[tensor][topSort[pat - 1]]) break; - Value p = (pat == 0) ? rewriter.create(loc, 0) + Value p = (pat == 0) ? constantIndex(rewriter, loc, 0) : codegen.pidxs[tensor][topSort[pat - 1]]; codegen.pidxs[tensor][idx] = genAddress( codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]); @@ -1217,7 +1195,7 @@ // Move the insertion indices in lexicographic index order. if (codegen.sparseOut) { - Value pos = rewriter.create(loc, at); + Value pos = constantIndex(rewriter, loc, at); rewriter.create(loc, codegen.loops[idx], codegen.lexIdx, pos); } @@ -1247,7 +1225,7 @@ // after the if-statements more closely resembles code generated by TACO. unsigned o = 0; SmallVector operands; - Value one = rewriter.create(loc, 1); + Value one = constantIndex(rewriter, loc, 1); for (unsigned b = 0, be = induction.size(); b < be; b++) { if (induction[b] && merger.isDim(b, Dim::kSparse)) { unsigned tensor = merger.tensor(b); @@ -1311,7 +1289,7 @@ clause = rewriter.create(loc, arith::CmpIPredicate::eq, op1, op2); } else { - clause = rewriter.create(loc, 1, 1); // true + clause = constantI1(rewriter, loc, true); } cond = cond ? rewriter.create(loc, cond, clause) : clause; } diff --git a/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt --- a/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt +++ b/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRSparseTensorUtils Merger.cpp + CodegenUtils.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor diff --git a/mlir/lib/Dialect/SparseTensor/Utils/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Utils/CodegenUtils.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SparseTensor/Utils/CodegenUtils.cpp @@ -0,0 +1,128 @@ +//===- CodegenUtils.cpp - Utilities for generating MLIR -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SparseTensor/Utils/CodegenUtils.h" + +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" + +using namespace mlir::sparse_tensor; + +//===----------------------------------------------------------------------===// +// ExecutionEngine/SparseTensorUtils helper functions. +//===----------------------------------------------------------------------===// + +OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) { + switch (width) { + default: + return OverheadType::kU64; + case 32: + return OverheadType::kU32; + case 16: + return OverheadType::kU16; + case 8: + return OverheadType::kU8; + } +} + +mlir::Type mlir::sparse_tensor::getOverheadType(mlir::Builder &builder, + OverheadType ot) { + switch (ot) { + case OverheadType::kU64: + return builder.getIntegerType(64); + case OverheadType::kU32: + return builder.getIntegerType(32); + case OverheadType::kU16: + return builder.getIntegerType(16); + case OverheadType::kU8: + return builder.getIntegerType(8); + } + llvm_unreachable("Unknown OverheadType"); +} + +mlir::Type mlir::sparse_tensor::getPointerOverheadType( + mlir::Builder &builder, const SparseTensorEncodingAttr &enc) { + // NOTE(wrengr): This workaround will be fixed in D115010. + unsigned width = enc.getPointerBitWidth(); + if (width == 0) + return builder.getIndexType(); + return getOverheadType(builder, overheadTypeEncoding(width)); +} + +mlir::Type +mlir::sparse_tensor::getIndexOverheadType(mlir::Builder &builder, + const SparseTensorEncodingAttr &enc) { + // NOTE(wrengr): This workaround will be fixed in D115010. + unsigned width = enc.getIndexBitWidth(); + if (width == 0) + return builder.getIndexType(); + return getOverheadType(builder, overheadTypeEncoding(width)); +} + +PrimaryType mlir::sparse_tensor::primaryTypeEncoding(mlir::Type elemTp) { + if (elemTp.isF64()) + return PrimaryType::kF64; + if (elemTp.isF32()) + return PrimaryType::kF32; + if (elemTp.isInteger(64)) + return PrimaryType::kI64; + if (elemTp.isInteger(32)) + return PrimaryType::kI32; + if (elemTp.isInteger(16)) + return PrimaryType::kI16; + if (elemTp.isInteger(8)) + return PrimaryType::kI8; + llvm_unreachable("Unknown primary type"); +} + +DimLevelType mlir::sparse_tensor::dimLevelTypeEncoding( + SparseTensorEncodingAttr::DimLevelType dlt) { + switch (dlt) { + case SparseTensorEncodingAttr::DimLevelType::Dense: + return DimLevelType::kDense; + case SparseTensorEncodingAttr::DimLevelType::Compressed: + return DimLevelType::kCompressed; + case SparseTensorEncodingAttr::DimLevelType::Singleton: + return DimLevelType::kSingleton; + } + llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType"); +} + +//===----------------------------------------------------------------------===// +// Misc code generators. +//===----------------------------------------------------------------------===// + +mlir::Attribute mlir::sparse_tensor::getOneAttr(mlir::Builder &builder, + mlir::Type tp) { + if (tp.isa()) + return builder.getFloatAttr(tp, 1.0); + if (tp.isa()) + return builder.getIndexAttr(1); + if (auto intTp = tp.dyn_cast()) + return builder.getIntegerAttr(tp, APInt(intTp.getWidth(), 1)); + if (tp.isa()) { + auto shapedTp = tp.cast(); + if (auto one = getOneAttr(builder, shapedTp.getElementType())) + return DenseElementsAttr::get(shapedTp, one); + } + llvm_unreachable("Unsupported attribute type"); +} + +mlir::Value mlir::sparse_tensor::genIsNonzero(mlir::OpBuilder &builder, + mlir::Location loc, + mlir::Value v) { + mlir::Type tp = v.getType(); + mlir::Value zero = constantZero(builder, loc, tp); + if (tp.isa()) + return builder.create(loc, arith::CmpFPredicate::UNE, v, + zero); + if (tp.isIntOrIndex()) + return builder.create(loc, arith::CmpIPredicate::ne, v, + zero); + llvm_unreachable("Non-numeric type"); +} diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/SparseTensor/Utils/Merger.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/SparseTensor/Utils/CodegenUtils.h" #include "mlir/IR/Operation.h" #include "llvm/Support/Debug.h" @@ -666,10 +667,7 @@ return rewriter.create(loc, v0); case kNegI: // no negi in std return rewriter.create( - loc, - rewriter.create(loc, v0.getType(), - rewriter.getZeroAttr(v0.getType())), - v0); + loc, constantZero(rewriter, loc, v0.getType()), v0); case kTruncF: return rewriter.create(loc, v0, inferType(e, v0)); case kExtF: diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -1797,7 +1797,6 @@ ":SideEffectInterfaces", ":SparseTensorAttrDefsIncGen", ":SparseTensorOpsIncGen", - ":SparseTensorUtils", ":StandardOps", "//llvm:Support", ], @@ -1806,16 +1805,17 @@ cc_library( name = "SparseTensorUtils", srcs = glob(["lib/Dialect/SparseTensor/Utils/*.cpp"]), - hdrs = glob(["include/mlir/Dialect/SparseTensor/Utils/*.h"]), + hdrs = glob([ + "include/mlir/Dialect/SparseTensor/Utils/*.h", + ]) + [ + "include/mlir/ExecutionEngine/SparseTensorUtils.h", + ], includes = ["include"], deps = [ ":ArithmeticDialect", ":IR", ":LinalgOps", - ":SideEffectInterfaces", - ":SparseTensorAttrDefsIncGen", - ":SparseTensorOpsIncGen", - ":StandardOps", + ":SparseTensor", "//llvm:Support", ], ) @@ -1825,6 +1825,7 @@ srcs = glob(["lib/Dialect/SparseTensor/Transforms/*.cpp"]), hdrs = [ "include/mlir/Dialect/SparseTensor/Transforms/Passes.h", + "include/mlir/Dialect/SparseTensor/Utils/CodegenUtils.h", "include/mlir/ExecutionEngine/SparseTensorUtils.h", ], includes = ["include"],