diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h @@ -34,6 +34,16 @@ kFloorF, kNegF, kNegI, + kTruncF, + kExtF, + kCastFS, // signed + kCastFU, // unsigned + kCastSF, // signed + kCastUF, // unsigned + kCastS, // signed + kCastU, // unsigned + kTruncI, + kBitCast, // Binary operations. kMulF, kMulI, @@ -73,8 +83,9 @@ Children children; }; - /// Direct link to IR for an invariant. During code generation, - /// field is used to cache "hoisted" loop invariant tensor loads. + /// Direct link to IR for an invariant or the destination value (to + /// infer destination type) of a cast operation. During code generation, + /// this field may be used to cache "hoisted" loop invariant tensor loads. Value val; }; @@ -115,6 +126,7 @@ /// Adds a tensor expression. Returns its index. unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value()); + unsigned addExp(Kind k, unsigned e, Value v) { return addExp(k, e, -1u, v); } unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); } /// Adds an iteration lattice point. Returns its index. @@ -140,7 +152,7 @@ /// Maps the unary operator over the lattice set of the operand, i.e. each /// lattice point on an expression E is simply copied over, but with OP E /// as new expression. Returns the index of the new set. - unsigned mapSet(Kind kind, unsigned s0); + unsigned mapSet(Kind kind, unsigned s0, Value v = Value()); /// Optimizes the iteration lattice points in the given set. This /// method should be called right before code generation to avoid @@ -220,6 +232,7 @@ private: bool maybeZero(unsigned e) const; bool isInvariant(unsigned e) const; + Type inferType(unsigned e, Value src); /// Traverses the SSA tree (possibly a DAG) to build a tensor expression. 
Optional buildTensorExp(linalg::GenericOp op, Value v); diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -14,9 +14,9 @@ namespace mlir { namespace sparse_tensor { -// +//===----------------------------------------------------------------------===// // Constructors. -// +//===----------------------------------------------------------------------===// TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v) : kind(k), val(v) { @@ -37,6 +37,20 @@ children.e0 = x; children.e1 = y; break; + case kTruncF: + case kExtF: + case kCastFS: + case kCastFU: + case kCastSF: + case kCastUF: + case kCastS: + case kCastU: + case kTruncI: + case kBitCast: + assert(x != -1u && y == -1u && v); + children.e0 = x; + children.e1 = y; + break; default: assert(x != -1u && y != -1u && !v); children.e0 = x; @@ -53,9 +67,9 @@ LatPoint::LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), simple(), exp(e) {} -// +//===----------------------------------------------------------------------===// // Lattice methods. 
-// +//===----------------------------------------------------------------------===// unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) { unsigned e = tensorExps.size(); @@ -109,11 +123,11 @@ return s; } -unsigned Merger::mapSet(Kind kind, unsigned s0) { - assert(kAbsF <= kind && kind <= kNegI); +unsigned Merger::mapSet(Kind kind, unsigned s0, Value v) { + assert(kAbsF <= kind && kind <= kBitCast); unsigned s = addSet(); for (unsigned p : latSets[s0]) { - unsigned e = addExp(kind, latPoints[p].exp); + unsigned e = addExp(kind, latPoints[p].exp, v); latPoints.push_back(LatPoint(latPoints[p].bits, e)); latSets[s].push_back(latPoints.size() - 1); } @@ -207,6 +221,16 @@ case kFloorF: case kNegF: case kNegI: + case kTruncF: + case kExtF: + case kCastFS: + case kCastFU: + case kCastSF: + case kCastUF: + case kCastS: + case kCastU: + case kTruncI: + case kBitCast: return isConjunction(t, tensorExps[e].children.e0); case kDivF: // note: x / c only case kDivS: @@ -230,9 +254,9 @@ #ifndef NDEBUG -// +//===----------------------------------------------------------------------===// // Print methods (for debugging). 
-// +//===----------------------------------------------------------------------===// static const char *kindToOpSymbol(Kind kind) { switch (kind) { @@ -250,6 +274,17 @@ return "-"; case kNegI: return "-"; + case kTruncF: + case kExtF: + case kCastFS: + case kCastFU: + case kCastSF: + case kCastUF: + case kCastS: + case kCastU: + case kTruncI: + case kBitCast: + return "cast"; case kMulF: return "*"; case kMulI: @@ -301,6 +336,16 @@ case kFloorF: case kNegF: case kNegI: + case kTruncF: + case kExtF: + case kCastFS: + case kCastFU: + case kCastSF: + case kCastUF: + case kCastS: + case kCastU: + case kTruncI: + case kBitCast: llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " "; dumpExp(tensorExps[e].children.e0); break; @@ -358,9 +403,9 @@ #endif // NDEBUG -// +//===----------------------------------------------------------------------===// // Builder methods. -// +//===----------------------------------------------------------------------===// unsigned Merger::buildLattices(unsigned e, unsigned i) { Kind kind = tensorExps[e].kind; @@ -380,13 +425,24 @@ case kFloorF: case kNegF: case kNegI: + case kTruncF: + case kExtF: + case kCastFS: + case kCastFU: + case kCastSF: + case kCastUF: + case kCastS: + case kCastU: + case kTruncI: + case kBitCast: // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the // lattice set of the operand through the operator into a new set. // // -y|!y | y | // --+---+---+ // | 0 |-y | - return mapSet(kind, buildLattices(tensorExps[e].children.e0, i)); + return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), + tensorExps[e].val); case kMulF: case kMulI: case kAndI: @@ -469,6 +525,16 @@ return tensorExps[e].kind == kInvariant; } +Type Merger::inferType(unsigned e, Value src) { + // Obtain the destination type from the cast node. + Type dtp = tensorExps[e].val.getType(); + // Inspect source type. For vector types, apply the same + // vectorization to the destination type. 
+ if (auto vtp = src.getType().dyn_cast()) + return VectorType::get(vtp.getNumElements(), dtp); + return dtp; +} + Optional Merger::buildTensorExp(linalg::GenericOp op, Value v) { if (auto arg = v.dyn_cast()) { unsigned argN = arg.getArgNumber(); @@ -501,12 +567,32 @@ if (isa(def)) return addExp(kFloorF, e); if (isa(def)) - return addExp(kNegF, e); - // TODO: no negi in std? + return addExp(kNegF, e); // TODO: no negi in std? + if (isa(def)) + return addExp(kTruncF, e, v); + if (isa(def)) + return addExp(kExtF, e, v); + if (isa(def)) + return addExp(kCastFS, e, v); + if (isa(def)) + return addExp(kCastFU, e, v); + if (isa(def)) + return addExp(kCastSF, e, v); + if (isa(def)) + return addExp(kCastUF, e, v); + if (isa(def)) + return addExp(kCastS, e, v); + if (isa(def)) + return addExp(kCastU, e, v); + if (isa(def)) + return addExp(kTruncI, e, v); + if (isa(def)) + return addExp(kBitCast, e, v); } } // Construct binary operations if subexpressions can be built. - // TODO: see buildLattices() for an explanation of rejecting certain divisions + // TODO: see buildLattices() for an explanation of rejecting + // certain division and shift operations if (def->getNumOperands() == 2) { auto x = buildTensorExp(op, def->getOperand(0)); auto y = buildTensorExp(op, def->getOperand(1)); @@ -555,6 +641,7 @@ case kTensor: case kInvariant: llvm_unreachable("unexpected non-op"); + // Unary ops. 
case kAbsF: return rewriter.create(loc, v0); case kCeilF: @@ -566,6 +653,27 @@ case kNegI: assert(v1); // no negi in std return rewriter.create(loc, v0, v1); + case kTruncF: + return rewriter.create(loc, v0, inferType(e, v0)); + case kExtF: + return rewriter.create(loc, v0, inferType(e, v0)); + case kCastFS: + return rewriter.create(loc, v0, inferType(e, v0)); + case kCastFU: + return rewriter.create(loc, v0, inferType(e, v0)); + case kCastSF: + return rewriter.create(loc, v0, inferType(e, v0)); + case kCastUF: + return rewriter.create(loc, v0, inferType(e, v0)); + case kCastS: + return rewriter.create(loc, v0, inferType(e, v0)); + case kCastU: + return rewriter.create(loc, v0, inferType(e, v0)); + case kTruncI: + return rewriter.create(loc, v0, inferType(e, v0)); + case kBitCast: + return rewriter.create(loc, v0, inferType(e, v0)); + // Binary ops. case kMulF: return rewriter.create(loc, v0, v1); case kMulI: diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_cast.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_cast.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_cast.mlir @@ -0,0 +1,277 @@ +// RUN: mlir-opt %s \ +// RUN: --sparsification --sparse-tensor-conversion \ +// RUN: --convert-vector-to-scf --convert-scf-to-std \ +// RUN: --func-bufferize --tensor-constant-bufferize --tensor-bufferize \ +// RUN: --std-bufferize --finalizing-bufferize --lower-affine \ +// RUN: --convert-vector-to-llvm --convert-memref-to-llvm --convert-std-to-llvm | \ +// RUN: mlir-cpu-runner \ +// RUN: -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s +// +// Do the same run, but now with SIMDization as well. This should not change the outcome. 
+// +// RUN: mlir-opt %s \ +// RUN: --sparsification="vectorization-strategy=2 vl=2 enable-simd-index32" --sparse-tensor-conversion \ +// RUN: --convert-vector-to-scf --convert-scf-to-std \ +// RUN: --func-bufferize --tensor-constant-bufferize --tensor-bufferize \ +// RUN: --std-bufferize --finalizing-bufferize --lower-affine \ +// RUN: --convert-vector-to-llvm --convert-memref-to-llvm --convert-std-to-llvm | \ +// RUN: TENSOR0="%mlir_integration_test_dir/data/test.mtx" \ +// RUN: mlir-cpu-runner \ +// RUN: -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s +// + +#SV = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }> + +#trait_cast = { + indexing_maps = [ + affine_map<(i) -> (i)>, // A (in) + affine_map<(i) -> (i)> // X (out) + ], + iterator_types = ["parallel"], + doc = "X(i) = cast A(i)" +} + +// +// Integration test that lowers a kernel annotated as sparse to actual sparse +// code, initializes a matching sparse storage scheme from a dense vector, +// and runs the resulting code with the JIT compiler. +// +module { + // + // Various kernels that cast a sparse vector from one type to another. + // Standard supports the following casts. + // sitofp + // uitofp + // fptosi + // fptoui + // fpext + // fptrunc + // sexti + // zexti + // trunci + // bitcast + // Since all casts are "zero preserving" unary operations, lattice computation + // and conversion to sparse code is straightforward. 
+ // + func @sparse_cast_s32_to_f32(%arga: tensor<10xi32, #SV>) -> tensor<10xf32> { + %argx = constant dense<0.0> : tensor<10xf32> + %0 = linalg.generic #trait_cast + ins(%arga: tensor<10xi32, #SV>) + outs(%argx: tensor<10xf32>) { + ^bb(%a: i32, %x : f32): + %cst = sitofp %a : i32 to f32 + linalg.yield %cst : f32 + } -> tensor<10xf32> + return %0 : tensor<10xf32> + } + func @sparse_cast_u32_to_f32(%arga: tensor<10xi32, #SV>) -> tensor<10xf32> { + %argx = constant dense<0.0> : tensor<10xf32> + %0 = linalg.generic #trait_cast + ins(%arga: tensor<10xi32, #SV>) + outs(%argx: tensor<10xf32>) { + ^bb(%a: i32, %x : f32): + %cst = uitofp %a : i32 to f32 + linalg.yield %cst : f32 + } -> tensor<10xf32> + return %0 : tensor<10xf32> + } + func @sparse_cast_f32_to_s32(%arga: tensor<10xf32, #SV>) -> tensor<10xi32> { + %argx = constant dense<0> : tensor<10xi32> + %0 = linalg.generic #trait_cast + ins(%arga: tensor<10xf32, #SV>) + outs(%argx: tensor<10xi32>) { + ^bb(%a: f32, %x : i32): + %cst = fptosi %a : f32 to i32 + linalg.yield %cst : i32 + } -> tensor<10xi32> + return %0 : tensor<10xi32> + } + func @sparse_cast_f64_to_u32(%arga: tensor<10xf64, #SV>) -> tensor<10xi32> { + %argx = constant dense<0> : tensor<10xi32> + %0 = linalg.generic #trait_cast + ins(%arga: tensor<10xf64, #SV>) + outs(%argx: tensor<10xi32>) { + ^bb(%a: f64, %x : i32): + %cst = fptoui %a : f64 to i32 + linalg.yield %cst : i32 + } -> tensor<10xi32> + return %0 : tensor<10xi32> + } + func @sparse_cast_f32_to_f64(%arga: tensor<10xf32, #SV>) -> tensor<10xf64> { + %argx = constant dense<0.0> : tensor<10xf64> + %0 = linalg.generic #trait_cast + ins(%arga: tensor<10xf32, #SV>) + outs(%argx: tensor<10xf64>) { + ^bb(%a: f32, %x : f64): + %cst = fpext %a : f32 to f64 + linalg.yield %cst : f64 + } -> tensor<10xf64> + return %0 : tensor<10xf64> + } + func @sparse_cast_f64_to_f32(%arga: tensor<10xf64, #SV>) -> tensor<10xf32> { + %argx = constant dense<0.0> : tensor<10xf32> + %0 = linalg.generic #trait_cast + ins(%arga: 
tensor<10xf64, #SV>) + outs(%argx: tensor<10xf32>) { + ^bb(%a: f64, %x : f32): + %cst = fptrunc %a : f64 to f32 + linalg.yield %cst : f32 + } -> tensor<10xf32> + return %0 : tensor<10xf32> + } + func @sparse_cast_s32_to_u64(%arga: tensor<10xi32, #SV>) -> tensor<10xi64> { + %argx = constant dense<0> : tensor<10xi64> + %0 = linalg.generic #trait_cast + ins(%arga: tensor<10xi32, #SV>) + outs(%argx: tensor<10xi64>) { + ^bb(%a: i32, %x : i64): + %cst = sexti %a : i32 to i64 + linalg.yield %cst : i64 + } -> tensor<10xi64> + return %0 : tensor<10xi64> + } + func @sparse_cast_u32_to_s64(%arga: tensor<10xi32, #SV>) -> tensor<10xi64> { + %argx = constant dense<0> : tensor<10xi64> + %0 = linalg.generic #trait_cast + ins(%arga: tensor<10xi32, #SV>) + outs(%argx: tensor<10xi64>) { + ^bb(%a: i32, %x : i64): + %cst = zexti %a : i32 to i64 + linalg.yield %cst : i64 + } -> tensor<10xi64> + return %0 : tensor<10xi64> + } + func @sparse_cast_i32_to_i8(%arga: tensor<10xi32, #SV>) -> tensor<10xi8> { + %argx = constant dense<0> : tensor<10xi8> + %0 = linalg.generic #trait_cast + ins(%arga: tensor<10xi32, #SV>) + outs(%argx: tensor<10xi8>) { + ^bb(%a: i32, %x : i8): + %cst = trunci %a : i32 to i8 + linalg.yield %cst : i8 + } -> tensor<10xi8> + return %0 : tensor<10xi8> + } + func @sparse_cast_f32_as_s32(%arga: tensor<10xf32, #SV>) -> tensor<10xi32> { + %argx = constant dense<0> : tensor<10xi32> + %0 = linalg.generic #trait_cast + ins(%arga: tensor<10xf32, #SV>) + outs(%argx: tensor<10xi32>) { + ^bb(%a: f32, %x : i32): + %cst = bitcast %a : f32 to i32 + linalg.yield %cst : i32 + } -> tensor<10xi32> + return %0 : tensor<10xi32> + } + + // + // Main driver that converts a dense tensor into a sparse tensor + // and then calls the sparse casting kernel. + // + func @entry() { + %z = constant 0 : index + %b = constant 0 : i8 + %i = constant 0 : i32 + %l = constant 0 : i64 + %f = constant 0.0 : f32 + %d = constant 0.0 : f64 + + // Initialize dense tensors, convert to sparse vectors. 
+ %0 = constant dense<[ -4, -3, -2, -1, 0, 1, 2, 3, 4, 305 ]> : tensor<10xi32> + %1 = sparse_tensor.convert %0 : tensor<10xi32> to tensor<10xi32, #SV> + %2 = constant dense<[ -4.4, -3.3, -2.2, -1.1, 0.0, 1.1, 2.2, 3.3, 4.4, 305.5 ]> : tensor<10xf32> + %3 = sparse_tensor.convert %2 : tensor<10xf32> to tensor<10xf32, #SV> + %4 = constant dense<[ -4.4, -3.3, -2.2, -1.1, 0.0, 1.1, 2.2, 3.3, 4.4, 305.5 ]> : tensor<10xf64> + %5 = sparse_tensor.convert %4 : tensor<10xf64> to tensor<10xf64, #SV> + %6 = constant dense<[ 4294967295.0, 4294967294.0, 4294967293.0, 4294967292.0, + 0.0, 1.1, 2.2, 3.3, 4.4, 305.5 ]> : tensor<10xf64> + %7 = sparse_tensor.convert %6 : tensor<10xf64> to tensor<10xf64, #SV> + + // + // CHECK: ( -4, -3, -2, -1, 0, 1, 2, 3, 4, 305 ) + // + %c0 = call @sparse_cast_s32_to_f32(%1) : (tensor<10xi32, #SV>) -> tensor<10xf32> + %m0 = memref.buffer_cast %c0 : memref<10xf32> + %v0 = vector.transfer_read %m0[%z], %f: memref<10xf32>, vector<10xf32> + vector.print %v0 : vector<10xf32> + + // + // CHECK: ( 4.29497e+09, 4.29497e+09, 4.29497e+09, 4.29497e+09, 0, 1, 2, 3, 4, 305 ) + // + %c1 = call @sparse_cast_u32_to_f32(%1) : (tensor<10xi32, #SV>) -> tensor<10xf32> + %m1 = memref.buffer_cast %c1 : memref<10xf32> + %v1 = vector.transfer_read %m1[%z], %f: memref<10xf32>, vector<10xf32> + vector.print %v1 : vector<10xf32> + + // + // CHECK: ( -4, -3, -2, -1, 0, 1, 2, 3, 4, 305 ) + // + %c2 = call @sparse_cast_f32_to_s32(%3) : (tensor<10xf32, #SV>) -> tensor<10xi32> + %m2 = memref.buffer_cast %c2 : memref<10xi32> + %v2 = vector.transfer_read %m2[%z], %i: memref<10xi32>, vector<10xi32> + vector.print %v2 : vector<10xi32> + + // + // CHECK: ( 4294967295, 4294967294, 4294967293, 4294967292, 0, 1, 2, 3, 4, 305 ) + // + %c3 = call @sparse_cast_f64_to_u32(%7) : (tensor<10xf64, #SV>) -> tensor<10xi32> + %m3 = memref.buffer_cast %c3 : memref<10xi32> + %v3 = vector.transfer_read %m3[%z], %i: memref<10xi32>, vector<10xi32> + %vu = vector.bitcast %v3 : vector<10xi32> to 
vector<10xui32> + vector.print %vu : vector<10xui32> + + // + // CHECK: ( -4.4, -3.3, -2.2, -1.1, 0, 1.1, 2.2, 3.3, 4.4, 305.5 ) + // + %c4 = call @sparse_cast_f32_to_f64(%3) : (tensor<10xf32, #SV>) -> tensor<10xf64> + %m4 = memref.buffer_cast %c4 : memref<10xf64> + %v4 = vector.transfer_read %m4[%z], %d: memref<10xf64>, vector<10xf64> + vector.print %v4 : vector<10xf64> + + // + // CHECK: ( -4.4, -3.3, -2.2, -1.1, 0, 1.1, 2.2, 3.3, 4.4, 305.5 ) + // + %c5 = call @sparse_cast_f64_to_f32(%5) : (tensor<10xf64, #SV>) -> tensor<10xf32> + %m5 = memref.buffer_cast %c5 : memref<10xf32> + %v5 = vector.transfer_read %m5[%z], %f: memref<10xf32>, vector<10xf32> + vector.print %v5 : vector<10xf32> + + // + // CHECK: ( -4, -3, -2, -1, 0, 1, 2, 3, 4, 305 ) + // + %c6 = call @sparse_cast_s32_to_u64(%1) : (tensor<10xi32, #SV>) -> tensor<10xi64> + %m6 = memref.buffer_cast %c6 : memref<10xi64> + %v6 = vector.transfer_read %m6[%z], %l: memref<10xi64>, vector<10xi64> + vector.print %v6 : vector<10xi64> + + // + // CHECK: ( 4294967292, 4294967293, 4294967294, 4294967295, 0, 1, 2, 3, 4, 305 ) + // + %c7 = call @sparse_cast_u32_to_s64(%1) : (tensor<10xi32, #SV>) -> tensor<10xi64> + %m7 = memref.buffer_cast %c7 : memref<10xi64> + %v7 = vector.transfer_read %m7[%z], %l: memref<10xi64>, vector<10xi64> + vector.print %v7 : vector<10xi64> + + // + // CHECK: ( -4, -3, -2, -1, 0, 1, 2, 3, 4, 49 ) + // + %c8 = call @sparse_cast_i32_to_i8(%1) : (tensor<10xi32, #SV>) -> tensor<10xi8> + %m8 = memref.buffer_cast %c8 : memref<10xi8> + %v8 = vector.transfer_read %m8[%z], %b: memref<10xi8>, vector<10xi8> + vector.print %v8 : vector<10xi8> + + // + // CHECK: ( -1064514355, -1068289229, -1072902963, -1081291571, 0, 1066192077, 1074580685, 1079194419, 1082969293, 1134084096 ) + // + %c9 = call @sparse_cast_f32_as_s32(%3) : (tensor<10xf32, #SV>) -> tensor<10xi32> + %m9 = memref.buffer_cast %c9 : memref<10xi32> + %v9 = vector.transfer_read %m9[%z], %i: memref<10xi32>, vector<10xi32> + vector.print %v9 : 
vector<10xi32> + + return + } +}