diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -821,18 +821,31 @@
   kAnyStorageInnerLoop
 };
 
+/// Defines a type for "pointer" and "index" storage in the sparse storage
+/// scheme, with a choice between the native platform-dependent index width,
+/// 64-bit integers, or 32-bit integers. A narrow width obviously reduces
+/// the memory footprint of the sparse storage scheme, but the width should
+/// suffice to define the total required range (viz. the maximum number of
+/// stored entries per indirection level for the "pointers" and the maximum
+/// value of each tensor index over all dimensions for the "indices").
+enum class SparseIntType { kNative, kI64, kI32 };
+
 /// Sparsification options.
 struct SparsificationOptions {
   SparsificationOptions(SparseParallelizationStrategy p,
-                        SparseVectorizationStrategy v, unsigned vl)
-      : parallelizationStrategy(p), vectorizationStrategy(v), vectorLength(vl) {
-  }
+                        SparseVectorizationStrategy v, unsigned vl,
+                        SparseIntType pt, SparseIntType it)
+      : parallelizationStrategy(p), vectorizationStrategy(v), vectorLength(vl),
+        ptrType(pt), indType(it) {}
   SparsificationOptions()
       : SparsificationOptions(SparseParallelizationStrategy::kNone,
-                              SparseVectorizationStrategy::kNone, 1u) {}
+                              SparseVectorizationStrategy::kNone, 1u,
+                              SparseIntType::kNative, SparseIntType::kNative) {}
   SparseParallelizationStrategy parallelizationStrategy;
   SparseVectorizationStrategy vectorizationStrategy;
   unsigned vectorLength;
+  SparseIntType ptrType;
+  SparseIntType indType;
 };
 
 /// Set up sparsification rewriting rules with the given options.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
@@ -420,16 +420,29 @@
   }
 }
 
+/// Maps sparse integer option to actual integral storage type.
+static Type genIntType(PatternRewriter &rewriter, linalg::SparseIntType tp) {
+  switch (tp) {
+  case linalg::SparseIntType::kNative:
+    return rewriter.getIndexType();
+  case linalg::SparseIntType::kI64:
+    return rewriter.getIntegerType(64);
+  case linalg::SparseIntType::kI32:
+    return rewriter.getIntegerType(32);
+  }
+  // Every enumerator returns above; only an out-of-range value gets here.
+  llvm_unreachable("unexpected SparseIntType");
+}
+
 /// Local bufferization of all dense and sparse data structures.
 /// This code enables testing the first prototype sparse compiler.
 // TODO: replace this with a proliferated bufferization strategy
-void genBuffers(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
-                linalg::GenericOp op) {
+static void genBuffers(Merger &merger, CodeGen &codegen,
+                       PatternRewriter &rewriter, linalg::GenericOp op) {
   Location loc = op.getLoc();
   unsigned numTensors = op.getNumInputsAndOutputs();
   unsigned numInputs = op.getNumInputs();
   assert(numTensors == numInputs + 1);
-  Type indexType = rewriter.getIndexType();
 
   // For now, set all unknown dimensions to 999.
   // TODO: compute these values (using sparsity or by reading tensor)
@@ -450,9 +463,13 @@
       // Handle sparse storage schemes.
       if (merger.isSparseAccess(t, i)) {
         allDense = false;
-        auto dynTp = MemRefType::get({ShapedType::kDynamicSize}, indexType);
-        codegen.pointers[t][i] = rewriter.create<AllocaOp>(loc, dynTp, unknown);
-        codegen.indices[t][i] = rewriter.create<AllocaOp>(loc, dynTp, unknown);
+        auto dynShape = {ShapedType::kDynamicSize};
+        auto ptrTp = MemRefType::get(
+            dynShape, genIntType(rewriter, codegen.options.ptrType));
+        auto indTp = MemRefType::get(
+            dynShape, genIntType(rewriter, codegen.options.indType));
+        codegen.pointers[t][i] = rewriter.create<AllocaOp>(loc, ptrTp, unknown);
+        codegen.indices[t][i] = rewriter.create<AllocaOp>(loc, indTp, unknown);
       }
       // Find lower and upper bound in current dimension.
       Value up;
@@ -516,6 +533,15 @@
   rewriter.create<StoreOp>(op.getLoc(), rhs, codegen.buffers[tensor], args);
 }
 
+/// Generates a pointer/index load from the sparse storage scheme.
+static Value genIntLoad(PatternRewriter &rewriter, Location loc, Value ptr,
+                        Value s) {
+  Value load = rewriter.create<LoadOp>(loc, ptr, s);
+  return load.getType().isa<IndexType>()
+             ? load
+             : rewriter.create<IndexCastOp>(loc, load, rewriter.getIndexType());
+}
+
 /// Recursively generates tensor expression.
 static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                     linalg::GenericOp op, unsigned exp) {
@@ -551,7 +577,6 @@
   unsigned idx = topSort[at];
 
   // Initialize sparse positions.
-  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
   for (unsigned b = 0, be = inits.size(); b < be; b++) {
     if (inits[b]) {
       unsigned tensor = merger.tensor(b);
@@ -564,11 +589,12 @@
             break;
         }
         Value ptr = codegen.pointers[tensor][idx];
-        Value p = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
-                             : codegen.pidxs[tensor][topSort[pat - 1]];
-        codegen.pidxs[tensor][idx] = rewriter.create<LoadOp>(loc, ptr, p);
-        p = rewriter.create<AddIOp>(loc, p, one);
-        codegen.highs[tensor][idx] = rewriter.create<LoadOp>(loc, ptr, p);
+        Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+        Value p0 = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
+                              : codegen.pidxs[tensor][topSort[pat - 1]];
+        codegen.pidxs[tensor][idx] = genIntLoad(rewriter, loc, ptr, p0);
+        Value p1 = rewriter.create<AddIOp>(loc, p0, one);
+        codegen.highs[tensor][idx] = genIntLoad(rewriter, loc, ptr, p1);
       } else {
         // Dense index still in play.
         needsUniv = true;
@@ -723,15 +749,17 @@
     if (locals[b] && merger.isSparseBit(b)) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
-      Value ld = rewriter.create<LoadOp>(loc, codegen.indices[tensor][idx],
-                                         codegen.pidxs[tensor][idx]);
-      codegen.idxs[tensor][idx] = ld;
+      Value ptr = codegen.indices[tensor][idx];
+      Value s = codegen.pidxs[tensor][idx];
+      Value load = genIntLoad(rewriter, loc, ptr, s);
+      codegen.idxs[tensor][idx] = load;
       if (!needsUniv) {
         if (min) {
-          Value cmp = rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, ld, min);
-          min = rewriter.create<SelectOp>(loc, cmp, ld, min);
+          Value cmp =
+              rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, load, min);
+          min = rewriter.create<SelectOp>(loc, cmp, load, min);
         } else {
-          min = ld;
+          min = load;
         }
       }
     }
diff --git a/mlir/test/Dialect/Linalg/sparse_storage.mlir b/mlir/test/Dialect/Linalg/sparse_storage.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/sparse_storage.mlir
@@ -0,0 +1,98 @@
+// RUN: mlir-opt %s -test-sparsification="ptr-type=1 ind-type=1" | \
+// RUN:   FileCheck %s --check-prefix=CHECK-TYPE0
+// RUN: mlir-opt %s -test-sparsification="ptr-type=1 ind-type=2" | \
+// RUN:   FileCheck %s --check-prefix=CHECK-TYPE1
+// RUN: mlir-opt %s -test-sparsification="ptr-type=2 ind-type=1" | \
+// RUN:   FileCheck %s --check-prefix=CHECK-TYPE2
+// RUN: mlir-opt %s -test-sparsification="ptr-type=2 ind-type=2" | \
+// RUN:   FileCheck %s --check-prefix=CHECK-TYPE3
+
+#trait_mul_1d = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>, // a
+    affine_map<(i) -> (i)>, // b
+    affine_map<(i) -> (i)> // x (out)
+  ],
+  sparse = [
+    [ "S" ], // a
+    [ "D" ], // b
+    [ "D" ] // x
+  ],
+  iterator_types = ["parallel"],
+  doc = "x(i) = a(i) * b(i)"
+}
+
+// CHECK-TYPE0-LABEL: func @mul_dd(
+// CHECK-TYPE0: %[[C0:.*]] = constant 0 : index
+// CHECK-TYPE0: %[[C1:.*]] = constant 1 : index
+// CHECK-TYPE0: %[[P0:.*]] = load %{{.*}}[%[[C0]]] : memref<?xi64>
+// CHECK-TYPE0: %[[B0:.*]] = index_cast %[[P0]] : i64 to index
+// CHECK-TYPE0: %[[P1:.*]] = load %{{.*}}[%[[C1]]] : memref<?xi64>
+// CHECK-TYPE0: %[[B1:.*]] = index_cast %[[P1]] : i64 to index
+// CHECK-TYPE0: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] {
+// CHECK-TYPE0: %[[IND0:.*]] = load %{{.*}}[%[[I]]] : memref<?xi64>
+// CHECK-TYPE0: %[[INDC:.*]] = index_cast %[[IND0]] : i64 to index
+// CHECK-TYPE0: %[[VAL0:.*]] = load %{{.*}}[%[[I]]] : memref<?xf64>
+// CHECK-TYPE0: %[[VAL1:.*]] = load %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK-TYPE0: %[[MUL:.*]] = mulf %[[VAL0]], %[[VAL1]] : f64
+// CHECK-TYPE0: store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK-TYPE0: }
+
+// CHECK-TYPE1-LABEL: func @mul_dd(
+// CHECK-TYPE1: %[[C0:.*]] = constant 0 : index
+// CHECK-TYPE1: %[[C1:.*]] = constant 1 : index
+// CHECK-TYPE1: %[[P0:.*]] = load %{{.*}}[%[[C0]]] : memref<?xi64>
+// CHECK-TYPE1: %[[B0:.*]] = index_cast %[[P0]] : i64 to index
+// CHECK-TYPE1: %[[P1:.*]] = load %{{.*}}[%[[C1]]] : memref<?xi64>
+// CHECK-TYPE1: %[[B1:.*]] = index_cast %[[P1]] : i64 to index
+// CHECK-TYPE1: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] {
+// CHECK-TYPE1: %[[IND0:.*]] = load %{{.*}}[%[[I]]] : memref<?xi32>
+// CHECK-TYPE1: %[[INDC:.*]] = index_cast %[[IND0]] : i32 to index
+// CHECK-TYPE1: %[[VAL0:.*]] = load %{{.*}}[%[[I]]] : memref<?xf64>
+// CHECK-TYPE1: %[[VAL1:.*]] = load %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK-TYPE1: %[[MUL:.*]] = mulf %[[VAL0]], %[[VAL1]] : f64
+// CHECK-TYPE1: store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK-TYPE1: }
+
+// CHECK-TYPE2-LABEL: func @mul_dd(
+// CHECK-TYPE2: %[[C0:.*]] = constant 0 : index
+// CHECK-TYPE2: %[[C1:.*]] = constant 1 : index
+// CHECK-TYPE2: %[[P0:.*]] = load %{{.*}}[%[[C0]]] : memref<?xi32>
+// CHECK-TYPE2: %[[B0:.*]] = index_cast %[[P0]] : i32 to index
+// CHECK-TYPE2: %[[P1:.*]] = load %{{.*}}[%[[C1]]] : memref<?xi32>
+// CHECK-TYPE2: %[[B1:.*]] = index_cast %[[P1]] : i32 to index
+// CHECK-TYPE2: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] {
+// CHECK-TYPE2: %[[IND0:.*]] = load %{{.*}}[%[[I]]] : memref<?xi64>
+// CHECK-TYPE2: %[[INDC:.*]] = index_cast %[[IND0]] : i64 to index
+// CHECK-TYPE2: %[[VAL0:.*]] = load %{{.*}}[%[[I]]] : memref<?xf64>
+// CHECK-TYPE2: %[[VAL1:.*]] = load %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK-TYPE2: %[[MUL:.*]] = mulf %[[VAL0]], %[[VAL1]] : f64
+// CHECK-TYPE2: store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK-TYPE2: }
+
+// CHECK-TYPE3-LABEL: func @mul_dd(
+// CHECK-TYPE3: %[[C0:.*]] = constant 0 : index
+// CHECK-TYPE3: %[[C1:.*]] = constant 1 : index
+// CHECK-TYPE3: %[[P0:.*]] = load %{{.*}}[%[[C0]]] : memref<?xi32>
+// CHECK-TYPE3: %[[B0:.*]] = index_cast %[[P0]] : i32 to index
+// CHECK-TYPE3: %[[P1:.*]] = load %{{.*}}[%[[C1]]] : memref<?xi32>
+// CHECK-TYPE3: %[[B1:.*]] = index_cast %[[P1]] : i32 to index
+// CHECK-TYPE3: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] {
+// CHECK-TYPE3: %[[IND0:.*]] = load %{{.*}}[%[[I]]] : memref<?xi32>
+// CHECK-TYPE3: %[[INDC:.*]] = index_cast %[[IND0]] : i32 to index
+// CHECK-TYPE3: %[[VAL0:.*]] = load %{{.*}}[%[[I]]] : memref<?xf64>
+// CHECK-TYPE3: %[[VAL1:.*]] = load %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK-TYPE3: %[[MUL:.*]] = mulf %[[VAL0]], %[[VAL1]] : f64
+// CHECK-TYPE3: store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK-TYPE3: }
+
+func @mul_dd(%arga: tensor<32xf64>, %argb: tensor<32xf64>) -> tensor<32xf64> {
+  %0 = linalg.generic #trait_mul_1d
+    ins(%arga, %argb: tensor<32xf64>, tensor<32xf64>) {
+      ^bb(%a: f64, %b: f64):
+        %0 = mulf %a, %b : f64
+        linalg.yield %0 : f64
+  } -> tensor<32xf64>
+  return %0 : tensor<32xf64>
+}
+
diff --git a/mlir/test/lib/Transforms/TestSparsification.cpp b/mlir/test/lib/Transforms/TestSparsification.cpp
--- a/mlir/test/lib/Transforms/TestSparsification.cpp
+++ b/mlir/test/lib/Transforms/TestSparsification.cpp
@@ -31,6 +31,14 @@
   Option<int32_t> vectorLength{
       *this, "vl", llvm::cl::desc("Set the vector length"), llvm::cl::init(1)};
 
+  Option<int32_t> ptrType{*this, "ptr-type",
+                          llvm::cl::desc("Set the pointer type"),
+                          llvm::cl::init(0)};
+
+  Option<int32_t> indType{*this, "ind-type",
+                          llvm::cl::desc("Set the index type"),
+                          llvm::cl::init(0)};
+
   /// Registers all dialects required by testing.
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<scf::SCFDialect, vector::VectorDialect>();
@@ -64,13 +72,26 @@
     }
   }
 
+  /// Returns the requested integer type.
+  linalg::SparseIntType typeOption(int32_t option) {
+    switch (option) {
+    default:
+      return linalg::SparseIntType::kNative;
+    case 1:
+      return linalg::SparseIntType::kI64;
+    case 2:
+      return linalg::SparseIntType::kI32;
+    }
+  }
+
   /// Runs the test on a function.
   void runOnFunction() override {
     auto *ctx = &getContext();
     OwningRewritePatternList patterns;
     // Translate strategy flags to strategy options.
     linalg::SparsificationOptions options(parallelOption(), vectorOption(),
-                                          vectorLength);
+                                          vectorLength, typeOption(ptrType),
+                                          typeOption(indType));
     // Apply rewriting.
     linalg::populateSparsificationPatterns(ctx, patterns, options);
     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));