diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -14,6 +14,7 @@ // //===----------------------------------------------------------------------===// +#include "../Utils/CodegenUtils.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" @@ -39,113 +40,6 @@ // Helper methods. //===----------------------------------------------------------------------===// -/// Generates a constant zero of the given type. -inline static Value constantZero(ConversionPatternRewriter &rewriter, - Location loc, Type t) { - return rewriter.create(loc, t, rewriter.getZeroAttr(t)); -} - -/// Generates a constant of `index` type. -inline static Value constantIndex(ConversionPatternRewriter &rewriter, - Location loc, int64_t i) { - return rewriter.create(loc, i); -} - -/// Generates a constant of `i32` type. -inline static Value constantI32(ConversionPatternRewriter &rewriter, - Location loc, int32_t i) { - return rewriter.create(loc, i, 32); -} - -/// Generates a constant of `i8` type. -inline static Value constantI8(ConversionPatternRewriter &rewriter, - Location loc, int8_t i) { - return rewriter.create(loc, i, 8); -} - -/// Generates a constant of the given `Action`. -static Value constantAction(ConversionPatternRewriter &rewriter, Location loc, - Action action) { - return constantI32(rewriter, loc, static_cast(action)); -} - -/// Generates a constant of the internal type encoding for overhead storage.
-static Value constantOverheadTypeEncoding(ConversionPatternRewriter &rewriter, - Location loc, unsigned width) { - OverheadType sec; - switch (width) { - default: - sec = OverheadType::kU64; - break; - case 32: - sec = OverheadType::kU32; - break; - case 16: - sec = OverheadType::kU16; - break; - case 8: - sec = OverheadType::kU8; - break; - } - return constantI32(rewriter, loc, static_cast(sec)); -} - -/// Generates a constant of the internal type encoding for pointer -/// overhead storage. -static Value constantPointerTypeEncoding(ConversionPatternRewriter &rewriter, - Location loc, - SparseTensorEncodingAttr &enc) { - return constantOverheadTypeEncoding(rewriter, loc, enc.getPointerBitWidth()); -} - -/// Generates a constant of the internal type encoding for index overhead -/// storage. -static Value constantIndexTypeEncoding(ConversionPatternRewriter &rewriter, - Location loc, - SparseTensorEncodingAttr &enc) { - return constantOverheadTypeEncoding(rewriter, loc, enc.getIndexBitWidth()); -} - -/// Generates a constant of the internal type encoding for primary storage. -static Value constantPrimaryTypeEncoding(ConversionPatternRewriter &rewriter, - Location loc, Type tp) { - PrimaryType primary; - if (tp.isF64()) - primary = PrimaryType::kF64; - else if (tp.isF32()) - primary = PrimaryType::kF32; - else if (tp.isInteger(64)) - primary = PrimaryType::kI64; - else if (tp.isInteger(32)) - primary = PrimaryType::kI32; - else if (tp.isInteger(16)) - primary = PrimaryType::kI16; - else if (tp.isInteger(8)) - primary = PrimaryType::kI8; - else - llvm_unreachable("Unknown element type"); - return constantI32(rewriter, loc, static_cast(primary)); -} - -/// Generates a constant of the internal dimension level type encoding.
-static Value -constantDimLevelTypeEncoding(ConversionPatternRewriter &rewriter, Location loc, - SparseTensorEncodingAttr::DimLevelType dlt) { - DimLevelType dlt2; - switch (dlt) { - case SparseTensorEncodingAttr::DimLevelType::Dense: - dlt2 = DimLevelType::kDense; - break; - case SparseTensorEncodingAttr::DimLevelType::Compressed: - dlt2 = DimLevelType::kCompressed; - break; - case SparseTensorEncodingAttr::DimLevelType::Singleton: - dlt2 = DimLevelType::kSingleton; - break; - } - return constantI8(rewriter, loc, static_cast(dlt2)); -} - /// Returns the equivalent of `void*` for opaque arguments to the /// execution engine. static Type getOpaquePointerType(PatternRewriter &rewriter) { @@ -336,22 +230,6 @@ params.push_back(ptr); } -/// Generates the comparison `v != 0` where `v` is of numeric type `t`. -/// For floating types, we use the "unordered" comparator (i.e., returns -/// true if `v` is NaN). -static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc, - Value v) { - Type t = v.getType(); - Value zero = constantZero(rewriter, loc, t); - if (t.isa()) - return rewriter.create(loc, arith::CmpFPredicate::UNE, v, - zero); - if (t.isIntOrIndex()) - return rewriter.create(loc, arith::CmpIPredicate::ne, v, - zero); - llvm_unreachable("Unknown element type"); -} - /// Generates the code to read the value from tensor[ivs], and conditionally /// stores the indices ivs to the memory in ind.
The generated code looks like /// the following and the insertion point after this routine is inside the diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#include "../Utils/CodegenUtils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" @@ -406,26 +407,16 @@ case kNoReduc: break; case kSum: - case kXor: { + case kXor: // Initialize reduction vector to: | 0 | .. | 0 | r | - Attribute zero = rewriter.getZeroAttr(vtp); - Value vec = rewriter.create(loc, vtp, zero); return rewriter.create( - loc, r, vec, rewriter.create(loc, 0)); - } - case kProduct: { + loc, r, constantZero(rewriter, loc, vtp), + constantIndex(rewriter, loc, 0)); + case kProduct: // Initialize reduction vector to: | 1 | .. | 1 | r | - Type etp = vtp.getElementType(); - Attribute one; - if (etp.isa()) - one = rewriter.getFloatAttr(etp, 1.0); - else - one = rewriter.getIntegerAttr(etp, 1); - Value vec = rewriter.create( - loc, vtp, DenseElementsAttr::get(vtp, one)); return rewriter.create( - loc, r, vec, rewriter.create(loc, 0)); - } + loc, r, constantOne(rewriter, loc, vtp), + constantIndex(rewriter, loc, 0)); case kAnd: case kOr: // Initialize reduction vector to: | r | .. | r | r | @@ -453,13 +444,6 @@ // Sparse compiler synthesis methods (statements and expressions). //===----------------------------------------------------------------------===// -/// Maps sparse integer option to actual integral storage type.
-static Type genIntType(PatternRewriter &rewriter, unsigned width) { - if (width == 0) - return rewriter.getIndexType(); - return rewriter.getIntegerType(width); -} - /// Generates buffer for the output tensor. Note that all sparse kernels /// assume that when all elements are written to (viz. x(i) = y(i) * z(i)), /// the output buffer is already initialized to all zeroes and only nonzeroes @@ -484,10 +468,8 @@ // materializes into the computation, we need to preserve the zero // initialization assumption of all sparse output buffers. if (isMaterializing(tensor)) { - Type tp = denseTp.getElementType(); Value alloc = rewriter.create(loc, denseTp, args); - Value zero = - rewriter.create(loc, tp, rewriter.getZeroAttr(tp)); + Value zero = constantZero(rewriter, loc, denseTp.getElementType()); rewriter.create(loc, zero, alloc); return alloc; } @@ -522,11 +504,11 @@ // Handle sparse storage schemes. if (merger.isDim(tensor, idx, Dim::kSparse)) { auto dynShape = {ShapedType::kDynamicSize}; - auto ptrTp = MemRefType::get( - dynShape, genIntType(rewriter, enc.getPointerBitWidth())); - auto indTp = MemRefType::get( - dynShape, genIntType(rewriter, enc.getIndexBitWidth())); - Value dim = rewriter.create(loc, d); + auto ptrTp = + MemRefType::get(dynShape, getPointerOverheadType(rewriter, enc)); + auto indTp = + MemRefType::get(dynShape, getIndexOverheadType(rewriter, enc)); + Value dim = constantIndex(rewriter, loc, d); // Generate sparse primitives to obtains pointer and indices. codegen.pointers[tensor][idx] = rewriter.create(loc, ptrTp, t->get(), dim); @@ -557,7 +539,7 @@ genOutputBuffer(codegen, rewriter, op, denseTp, args); } else if (t == codegen.sparseOut) { // True sparse output needs a lexIdx array.
- Value rank = rewriter.create(loc, op.getRank(t)); + Value rank = constantIndex(rewriter, loc, op.getRank(t)); auto dynShape = {ShapedType::kDynamicSize}; auto memTp = MemRefType::get(dynShape, rewriter.getIndexType()); codegen.lexIdx = rewriter.create(loc, memTp, rank); @@ -585,7 +567,7 @@ static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter, Value iv, Value lo, Value hi, Value step) { Location loc = iv.getLoc(); - VectorType mtp = vectorType(codegen, genIntType(rewriter, 1)); + VectorType mtp = vectorType(codegen, rewriter.getI1Type()); // Special case if the vector length evenly divides the trip count (for // example, "for i = 0, 128, 16"). A constant all-true mask is generated // so that all subsequent masked memory operations are immediately folded @@ -596,7 +578,7 @@ matchPattern(step, m_Constant(&stepInt))) { if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) return rewriter.create( - loc, mtp, rewriter.create(loc, 1, 1)); + loc, mtp, constantI1(rewriter, loc, true)); } // Otherwise, generate a vector mask that avoids overrunning the upperbound // during vector execution.
Here we rely on subsequent loop optimizations to @@ -617,12 +599,11 @@ Value ptr, ArrayRef args) { Location loc = ptr.getLoc(); VectorType vtp = vectorType(codegen, ptr); - Value pass = - rewriter.create(loc, vtp, rewriter.getZeroAttr(vtp)); + Value pass = constantZero(rewriter, loc, vtp); if (args.back().getType().isa()) { SmallVector scalarArgs(args.begin(), args.end()); Value indexVec = args.back(); - scalarArgs.back() = rewriter.create(loc, 0); + scalarArgs.back() = constantIndex(rewriter, loc, 0); return rewriter.create( loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass); } @@ -637,7 +618,7 @@ if (args.back().getType().isa()) { SmallVector scalarArgs(args.begin(), args.end()); Value indexVec = args.back(); - scalarArgs.back() = rewriter.create(loc, 0); + scalarArgs.back() = constantIndex(rewriter, loc, 0); rewriter.create(loc, ptr, scalarArgs, indexVec, codegen.curVecMask, rhs); return; @@ -679,7 +660,7 @@ } case AffineExprKind::Constant: { int64_t c = a.cast().getValue(); - return rewriter.create(loc, c); + return constantIndex(rewriter, loc, c); } default: llvm_unreachable("unexpected affine subscript"); @@ -728,8 +709,7 @@ // Direct lexicographic index order, tensor loads as zero. if (!codegen.expValues) { Type tp = getElementTypeOrSelf(t->get().getType()); - return rewriter.create(loc, tp, - rewriter.getZeroAttr(tp)); + return constantZero(rewriter, loc, tp); } // Load from expanded access pattern. Value index = genIndex(codegen, op, t); @@ -752,8 +732,8 @@ // endif // values[i] = rhs Value index = genIndex(codegen, op, t); - Value fval = rewriter.create(loc, 0, 1); // false - Value tval = rewriter.create(loc, 1, 1); // true + Value fval = constantI1(rewriter, loc, false); + Value tval = constantI1(rewriter, loc, true); // If statement.
Value filled = rewriter.create(loc, codegen.expFilled, index); Value cond = rewriter.create(loc, arith::CmpIPredicate::eq, @@ -765,7 +745,7 @@ rewriter.create(loc, tval, codegen.expFilled, index); rewriter.create(loc, index, codegen.expAdded, codegen.expCount); - Value one = rewriter.create(loc, 1); + Value one = constantIndex(rewriter, loc, 1); Value add = rewriter.create(loc, codegen.expCount, one); rewriter.create(loc, add); // False branch. @@ -852,11 +832,11 @@ if (!etp.isa()) { if (etp.getIntOrFloatBitWidth() < 32) vload = rewriter.create( - loc, vload, vectorType(codegen, genIntType(rewriter, 32))); + loc, vload, vectorType(codegen, rewriter.getI32Type())); else if (etp.getIntOrFloatBitWidth() < 64 && !codegen.options.enableSIMDIndex32) vload = rewriter.create( - loc, vload, vectorType(codegen, genIntType(rewriter, 64))); + loc, vload, vectorType(codegen, rewriter.getI64Type())); } return vload; } @@ -867,8 +847,7 @@ Value load = rewriter.create(loc, ptr, s); if (!load.getType().isa()) { if (load.getType().getIntOrFloatBitWidth() < 64) - load = - rewriter.create(loc, load, genIntType(rewriter, 64)); + load = rewriter.create(loc, load, rewriter.getI64Type()); load = rewriter.create(loc, load, rewriter.getIndexType()); } @@ -1000,8 +979,8 @@ auto dynShape = {ShapedType::kDynamicSize}; Type etp = tensor.getType().cast().getElementType(); Type t1 = MemRefType::get(dynShape, etp); - Type t2 = MemRefType::get(dynShape, genIntType(rewriter, 1)); - Type t3 = MemRefType::get(dynShape, genIntType(rewriter, 0)); + Type t2 = MemRefType::get(dynShape, rewriter.getI1Type()); + Type t3 = MemRefType::get(dynShape, rewriter.getIndexType()); Type t4 = rewriter.getIndexType(); auto res = rewriter.create(loc, TypeRange({t1, t2, t3, t4}), tensor); @@ -1044,8 +1023,8 @@ break; } Value ptr = codegen.pointers[tensor][idx]; - Value one = rewriter.create(loc, 1); - Value p0 = (pat == 0) ? rewriter.create(loc, 0) + Value one = constantIndex(rewriter, loc, 1); + Value p0 = (pat == 0) ?
constantIndex(rewriter, loc, 0) : codegen.pidxs[tensor][topSort[pat - 1]]; codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0); Value p1 = rewriter.create(loc, p0, one); @@ -1058,7 +1037,7 @@ } // Initialize the universal dense index. - codegen.loops[idx] = rewriter.create(loc, 0); + codegen.loops[idx] = constantIndex(rewriter, loc, 0); return needsUniv; } @@ -1148,8 +1127,7 @@ Location loc = op.getLoc(); Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx]; Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx]; - Value step = - rewriter.create(loc, codegen.curVecLength); + Value step = constantIndex(rewriter, loc, codegen.curVecLength); // Emit a parallel loop. if (isParallel) { @@ -1323,7 +1301,7 @@ for (; pat != 0; pat--) if (codegen.pidxs[tensor][topSort[pat - 1]]) break; - Value p = (pat == 0) ? rewriter.create(loc, 0) + Value p = (pat == 0) ? constantIndex(rewriter, loc, 0) : codegen.pidxs[tensor][topSort[pat - 1]]; codegen.pidxs[tensor][idx] = genAddress( codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]); @@ -1333,7 +1311,7 @@ // Move the insertion indices in lexicographic index order. During access // pattern expansion, we can skip setting the innermost dimension. if (codegen.sparseOut && !codegen.expValues) { - Value pos = rewriter.create(loc, at); + Value pos = constantIndex(rewriter, loc, at); rewriter.create(loc, codegen.loops[idx], codegen.lexIdx, pos); } @@ -1373,7 +1351,7 @@ // after the if-statements more closely resembles code generated by TACO.
unsigned o = 0; SmallVector operands; - Value one = rewriter.create(loc, 1); + Value one = constantIndex(rewriter, loc, 1); for (unsigned b = 0, be = induction.size(); b < be; b++) { if (induction[b] && merger.isDim(b, Dim::kSparse)) { unsigned tensor = merger.tensor(b); @@ -1445,7 +1423,7 @@ clause = rewriter.create(loc, arith::CmpIPredicate::eq, op1, op2); } else { - clause = rewriter.create(loc, 1, 1); // true + clause = constantI1(rewriter, loc, true); } cond = cond ? rewriter.create(loc, cond, clause) : clause; } diff --git a/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt --- a/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt +++ b/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRSparseTensorUtils Merger.cpp + CodegenUtils.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor diff --git a/mlir/lib/Dialect/SparseTensor/Utils/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Utils/CodegenUtils.h new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SparseTensor/Utils/CodegenUtils.h @@ -0,0 +1,166 @@ +//===- CodegenUtils.h - Utilities for generating MLIR -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines utilities for generating MLIR for sparse tensors.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SPARSETENSOR_UTILS_CODEGENUTILS_H_ +#define MLIR_DIALECT_SPARSETENSOR_UTILS_CODEGENUTILS_H_ + +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/ExecutionEngine/SparseTensorUtils.h" +#include "mlir/IR/Builders.h" + +namespace mlir { +class Location; +class Type; +class Value; + +namespace sparse_tensor { + +//===----------------------------------------------------------------------===// +// ExecutionEngine/SparseTensorUtils helper functions. +//===----------------------------------------------------------------------===// + +/// Converts overhead bitwidth 8/16/32 to its type-encoding; else kU64. +OverheadType overheadTypeEncoding(unsigned width); + +/// Converts the internal type-encoding for overhead storage to an mlir::Type. +Type getOverheadType(Builder &builder, OverheadType ot); + +/// Returns the mlir::Type for pointer overhead storage (`index` if width 0). +Type getPointerOverheadType(Builder &builder, + const SparseTensorEncodingAttr &enc); + +/// Returns the mlir::Type for index overhead storage (`index` if width 0). +Type getIndexOverheadType(Builder &builder, + const SparseTensorEncodingAttr &enc); + +/// Converts a primary storage type to its internal type-encoding. +PrimaryType primaryTypeEncoding(Type elemTp); + +/// Converts the IR's dimension level type to its internal type-encoding. +DimLevelType dimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt); + +//===----------------------------------------------------------------------===// +// Misc code generators. +// +// TODO: both of these should move upstream to their respective classes. +// Once RFCs have been created for those changes, list them here. +//===----------------------------------------------------------------------===// + +/// Generates a 1-valued attribute of the given type.
This supports +/// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`, +/// for unsupported types we raise `llvm_unreachable` rather than +/// returning a null attribute. +Attribute getOneAttr(Builder &builder, Type tp); + +/// Generates the comparison `v != 0` where `v` is of numeric type. +/// For floating types, we use the "unordered" comparator (i.e., returns +/// true if `v` is NaN). +Value genIsNonzero(OpBuilder &builder, Location loc, Value v); + +//===----------------------------------------------------------------------===// +// Constant generators. +// +// All these functions are just wrappers to improve code legibility; +// they are defined `inline` in this header so that the added legibility +// need not cost any extra call overhead. +// +// TODO: Ideally these should move upstream, so that we don't +// develop a design island. However, doing so will involve +// substantial design work. NOTE(review): prior-discussion link missing. +//===----------------------------------------------------------------------===// + +/// Generates a 0-valued constant of the given type. In addition to +/// the scalar types (`FloatType`, `IndexType`, `IntegerType`), this also +/// works for `RankedTensorType` and `VectorType` (for which it generates +/// a constant `DenseElementsAttr` of zeros). +inline Value constantZero(OpBuilder &builder, Location loc, Type tp) { + return builder.create(loc, tp, builder.getZeroAttr(tp)); +} + +/// Generates a 1-valued constant of the given type. This supports all +/// the same types as `constantZero`. +inline Value constantOne(OpBuilder &builder, Location loc, Type tp) { + return builder.create(loc, tp, getOneAttr(builder, tp)); +} + +/// Generates a constant of `index` type. +inline Value constantIndex(OpBuilder &builder, Location loc, int64_t i) { + return builder.create(loc, i); +} + +/// Generates a constant of `i32` type.
+inline Value constantI32(OpBuilder &builder, Location loc, int32_t i) { + return builder.create(loc, i, 32); +} + +/// Generates a constant of `i16` type. +inline Value constantI16(OpBuilder &builder, Location loc, int16_t i) { + return builder.create(loc, i, 16); +} + +/// Generates a constant of `i8` type. +inline Value constantI8(OpBuilder &builder, Location loc, int8_t i) { + return builder.create(loc, i, 8); +} + +/// Generates a constant of `i1` type. +inline Value constantI1(OpBuilder &builder, Location loc, bool b) { + return builder.create(loc, b, 1); +} + +/// Generates a constant of the given `Action`. +inline Value constantAction(OpBuilder &builder, Location loc, Action action) { + return constantI32(builder, loc, static_cast(action)); +} + +/// Generates a constant of the internal type-encoding for overhead storage. +inline Value constantOverheadTypeEncoding(OpBuilder &builder, Location loc, + unsigned width) { + return constantI32(builder, loc, + static_cast(overheadTypeEncoding(width))); +} + +/// Generates a constant of the internal type-encoding for pointer +/// overhead storage. +inline Value constantPointerTypeEncoding(OpBuilder &builder, Location loc, + const SparseTensorEncodingAttr &enc) { + return constantOverheadTypeEncoding(builder, loc, enc.getPointerBitWidth()); +} + +/// Generates a constant of the internal type-encoding for index overhead +/// storage. +inline Value constantIndexTypeEncoding(OpBuilder &builder, Location loc, + const SparseTensorEncodingAttr &enc) { + return constantOverheadTypeEncoding(builder, loc, enc.getIndexBitWidth()); +} + +/// Generates a constant of the internal type-encoding for primary storage. +inline Value constantPrimaryTypeEncoding(OpBuilder &builder, Location loc, + Type elemTp) { + return constantI32(builder, loc, + static_cast(primaryTypeEncoding(elemTp))); +} + +/// Generates a constant of the internal dimension-level type-encoding.
+inline Value +constantDimLevelTypeEncoding(OpBuilder &builder, Location loc, + SparseTensorEncodingAttr::DimLevelType dlt) { + return constantI8(builder, loc, + static_cast(dimLevelTypeEncoding(dlt))); +} + +} // namespace sparse_tensor +} // namespace mlir + +#endif // MLIR_DIALECT_SPARSETENSOR_UTILS_CODEGENUTILS_H_ diff --git a/mlir/lib/Dialect/SparseTensor/Utils/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Utils/CodegenUtils.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SparseTensor/Utils/CodegenUtils.cpp @@ -0,0 +1,128 @@ +//===- CodegenUtils.cpp - Utilities for generating MLIR -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "CodegenUtils.h" + +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" + +using namespace mlir::sparse_tensor; + +//===----------------------------------------------------------------------===// +// ExecutionEngine/SparseTensorUtils helper functions.
+//===----------------------------------------------------------------------===// + +OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) { + switch (width) { + default: + return OverheadType::kU64; + case 32: + return OverheadType::kU32; + case 16: + return OverheadType::kU16; + case 8: + return OverheadType::kU8; + } +} + +mlir::Type mlir::sparse_tensor::getOverheadType(mlir::Builder &builder, + OverheadType ot) { + switch (ot) { + case OverheadType::kU64: + return builder.getIntegerType(64); + case OverheadType::kU32: + return builder.getIntegerType(32); + case OverheadType::kU16: + return builder.getIntegerType(16); + case OverheadType::kU8: + return builder.getIntegerType(8); + } + llvm_unreachable("Unknown OverheadType"); +} + +mlir::Type mlir::sparse_tensor::getPointerOverheadType( + mlir::Builder &builder, const SparseTensorEncodingAttr &enc) { + // NOTE(wrengr): This width-0 => `index` workaround will be fixed in D115010. + unsigned width = enc.getPointerBitWidth(); + if (width == 0) + return builder.getIndexType(); + return getOverheadType(builder, overheadTypeEncoding(width)); +} + +mlir::Type +mlir::sparse_tensor::getIndexOverheadType(mlir::Builder &builder, + const SparseTensorEncodingAttr &enc) { + // NOTE(wrengr): This width-0 => `index` workaround will be fixed in D115010.
+ unsigned width = enc.getIndexBitWidth(); + if (width == 0) + return builder.getIndexType(); + return getOverheadType(builder, overheadTypeEncoding(width)); +} + +PrimaryType mlir::sparse_tensor::primaryTypeEncoding(mlir::Type elemTp) { + if (elemTp.isF64()) + return PrimaryType::kF64; + if (elemTp.isF32()) + return PrimaryType::kF32; + if (elemTp.isInteger(64)) + return PrimaryType::kI64; + if (elemTp.isInteger(32)) + return PrimaryType::kI32; + if (elemTp.isInteger(16)) + return PrimaryType::kI16; + if (elemTp.isInteger(8)) + return PrimaryType::kI8; + llvm_unreachable("Unknown primary type"); +} + +DimLevelType mlir::sparse_tensor::dimLevelTypeEncoding( + SparseTensorEncodingAttr::DimLevelType dlt) { + switch (dlt) { + case SparseTensorEncodingAttr::DimLevelType::Dense: + return DimLevelType::kDense; + case SparseTensorEncodingAttr::DimLevelType::Compressed: + return DimLevelType::kCompressed; + case SparseTensorEncodingAttr::DimLevelType::Singleton: + return DimLevelType::kSingleton; + } + llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType"); +} + +//===----------------------------------------------------------------------===// +// Misc code generators.
+//===----------------------------------------------------------------------===// + +mlir::Attribute mlir::sparse_tensor::getOneAttr(mlir::Builder &builder, + mlir::Type tp) { + if (tp.isa()) + return builder.getFloatAttr(tp, 1.0); + if (tp.isa()) + return builder.getIndexAttr(1); + if (auto intTp = tp.dyn_cast()) + return builder.getIntegerAttr(tp, APInt(intTp.getWidth(), 1)); + if (tp.isa()) { + auto shapedTp = tp.cast(); + if (auto one = getOneAttr(builder, shapedTp.getElementType())) + return DenseElementsAttr::get(shapedTp, one); + } + llvm_unreachable("Unsupported attribute type"); +} + +mlir::Value mlir::sparse_tensor::genIsNonzero(mlir::OpBuilder &builder, + mlir::Location loc, + mlir::Value v) { + mlir::Type tp = v.getType(); + mlir::Value zero = constantZero(builder, loc, tp); + if (tp.isa()) + return builder.create(loc, arith::CmpFPredicate::UNE, v, + zero); + if (tp.isIntOrIndex()) + return builder.create(loc, arith::CmpIPredicate::ne, v, + zero); + llvm_unreachable("Non-numeric type"); +} diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/SparseTensor/Utils/Merger.h" +#include "CodegenUtils.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/IR/Operation.h" @@ -666,10 +667,7 @@ return rewriter.create(loc, v0); case kNegI: // no negi in std return rewriter.create( - loc, - rewriter.create(loc, v0.getType(), - rewriter.getZeroAttr(v0.getType())), - v0); + loc, constantZero(rewriter, loc, v0.getType()), v0); case kTruncF: return rewriter.create(loc, v0, inferType(e, v0)); case kExtF: diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++
b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -1797,7 +1797,6 @@ ":SideEffectInterfaces", ":SparseTensorAttrDefsIncGen", ":SparseTensorOpsIncGen", - ":SparseTensorUtils", ":StandardOps", "//llvm:Support", ], @@ -1805,24 +1804,32 @@ cc_library( name = "SparseTensorUtils", - srcs = glob(["lib/Dialect/SparseTensor/Utils/*.cpp"]), - hdrs = glob(["include/mlir/Dialect/SparseTensor/Utils/*.h"]), + srcs = glob([ + "lib/Dialect/SparseTensor/Utils/*.cpp", + "lib/Dialect/SparseTensor/Utils/*.h", + ]), + hdrs = glob([ + "include/mlir/Dialect/SparseTensor/Utils/*.h", + ]) + [ + "include/mlir/ExecutionEngine/SparseTensorUtils.h", + ], includes = ["include"], deps = [ ":ArithmeticDialect", ":IR", ":LinalgOps", - ":SideEffectInterfaces", - ":SparseTensorAttrDefsIncGen", - ":SparseTensorOpsIncGen", - ":StandardOps", + ":SparseTensor", "//llvm:Support", ], ) cc_library( name = "SparseTensorTransforms", - srcs = glob(["lib/Dialect/SparseTensor/Transforms/*.cpp"]), + srcs = glob([ + "lib/Dialect/SparseTensor/Transforms/*.cpp", + ]) + [ + "lib/Dialect/SparseTensor/Utils/CodegenUtils.h", + ], hdrs = [ "include/mlir/Dialect/SparseTensor/Transforms/Passes.h", "include/mlir/ExecutionEngine/SparseTensorUtils.h",