diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h @@ -48,16 +48,15 @@ /// Sparsification options. struct SparsificationOptions { SparsificationOptions(SparseParallelizationStrategy p, - SparseVectorizationStrategy v, unsigned vl, bool fo) - : parallelizationStrategy(p), vectorizationStrategy(v), vectorLength(vl), - fastOutput(fo) {} + SparseVectorizationStrategy v, unsigned vl) + : parallelizationStrategy(p), vectorizationStrategy(v), vectorLength(vl) { + } SparsificationOptions() : SparsificationOptions(SparseParallelizationStrategy::kNone, - SparseVectorizationStrategy::kNone, 1u, false) {} + SparseVectorizationStrategy::kNone, 1u) {} SparseParallelizationStrategy parallelizationStrategy; SparseVectorizationStrategy vectorizationStrategy; unsigned vectorLength; - bool fastOutput; // experimental: fast output buffers }; /// Sets up sparsification rewriting rules with the given options. diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -45,10 +45,6 @@ Option vectorLength{ *this, "vl", llvm::cl::desc("Set the vector length"), llvm::cl::init(1)}; - Option fastOutput{*this, "fast-output", - llvm::cl::desc("Allows fast output buffers"), - llvm::cl::init(false)}; - /// Returns parallelization strategy given on command line. SparseParallelizationStrategy parallelOption() { switch (parallelization) { @@ -82,7 +78,7 @@ RewritePatternSet patterns(ctx); // Translate strategy flags to strategy options. SparsificationOptions options(parallelOption(), vectorOption(), - vectorLength, fastOutput); + vectorLength); // Apply rewriting. 
populateSparsificationPatterns(patterns, options); vector::populateVectorToVectorCanonicalizationPatterns(patterns); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -506,6 +506,16 @@ return rewriter.getIntegerType(width); } +/// Detects in-place annotation on tensor argument. +static bool getInPlace(Value val) { + if (auto arg = val.dyn_cast<BlockArgument>()) + if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp())) + if (auto attr = funcOp.getArgAttrOfType<BoolAttr>( + arg.getArgNumber(), linalg::LinalgDialect::kInplaceableAttrName)) + return attr.getValue(); + return false; +} + /// Generates buffer for the output tensor. static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter, linalg::GenericOp op, MemRefType denseTp, @@ -515,9 +525,8 @@ // The output tensor simply could materialize from the buffer that will // be generated for the tensor present in the outs() clause. This has // the major advantage that the sparse kernel only updates the nonzero - // positions for the output tensor. Currently this results in functional, - // but slightly imprecise IR, so it is put under an experimental option. - if (codegen.options.fastOutput) + // positions for the output tensor. + if (getInPlace(tensor)) return rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor); // By default, a new buffer is allocated which is initialized to the // tensor defined in the outs() clause. 
This is always correct but diff --git a/mlir/test/Dialect/SparseTensor/sparse_lower.mlir b/mlir/test/Dialect/SparseTensor/sparse_lower.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_lower.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_lower.mlir @@ -7,11 +7,6 @@ // RUN: --convert-linalg-to-loops --func-bufferize --tensor-constant-bufferize \ // RUN: --tensor-bufferize --finalizing-bufferize | \ // RUN: FileCheck %s --check-prefix=CHECK-LIR -// -// RUN: mlir-opt %s -sparsification="fast-output" --sparse-tensor-conversion \ -// RUN: --convert-linalg-to-loops --func-bufferize --tensor-constant-bufferize \ -// RUN: --tensor-bufferize --finalizing-bufferize | \ -// RUN: FileCheck %s --check-prefix=CHECK-FAST #CSR = #sparse_tensor.encoding<{dimLevelType = [ "dense", "compressed" ]}> @@ -32,9 +27,9 @@ // CHECK-HIR: %[[VAL_3:.*]] = constant 64 : index // CHECK-HIR: %[[VAL_4:.*]] = constant 0 : index // CHECK-HIR: %[[VAL_5:.*]] = constant 1 : index -// CHECK-HIR: %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<64x64xf64, #{{.*}}> to memref -// CHECK-HIR: %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<64x64xf64, #{{.*}}> to memref -// CHECK-HIR: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<64x64xf64, #{{.*}}> to memref +// CHECK-HIR: %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<64x64xf64, #sparse_tensor.encoding<{{.*}}>> to memref +// CHECK-HIR: %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<64x64xf64, #sparse_tensor.encoding<{{.*}}>> to memref +// CHECK-HIR: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<64x64xf64, #sparse_tensor.encoding<{{.*}}>> to memref // CHECK-HIR: %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_1]] : memref<64xf64> // CHECK-HIR: %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_2]] : memref<64xf64> // CHECK-HIR: %[[VAL_12:.*]] = memref.alloc() : memref<64xf64> @@ -127,35 +122,9 @@ // CHECK-LIR: return %[[VAL_9]] : memref<64xf64> // CHECK-LIR: 
} -// CHECK-FAST-LABEL: func @matvec( -// CHECK-FAST-SAME: %[[VAL_0:.*]]: !llvm.ptr, -// CHECK-FAST-SAME: %[[VAL_1:.*]]: memref<64xf64>, -// CHECK-FAST-SAME: %[[VAL_2:.*]]: memref<64xf64>) -> memref<64xf64> { -// CHECK-FAST: %[[VAL_3:.*]] = constant 64 : index -// CHECK-FAST: %[[VAL_4:.*]] = constant 0 : index -// CHECK-FAST: %[[VAL_5:.*]] = constant 1 : index -// CHECK-FAST: %[[VAL_6:.*]] = call @sparsePointers(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref -// CHECK-FAST: %[[VAL_7:.*]] = call @sparseIndices(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref -// CHECK-FAST: %[[VAL_8:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr) -> memref -// CHECK-FAST: scf.for %[[VAL_9:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] { -// CHECK-FAST: %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_9]]] : memref -// CHECK-FAST: %[[VAL_11:.*]] = addi %[[VAL_9]], %[[VAL_5]] : index -// CHECK-FAST: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref -// CHECK-FAST: %[[VAL_13:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] : memref<64xf64> -// CHECK-FAST: %[[VAL_14:.*]] = scf.for %[[VAL_15:.*]] = %[[VAL_10]] to %[[VAL_12]] step %[[VAL_5]] iter_args(%[[VAL_16:.*]] = %[[VAL_13]]) -> (f64) { -// CHECK-FAST: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref -// CHECK-FAST: %[[VAL_18:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_15]]] : memref -// CHECK-FAST: %[[VAL_19:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_17]]] : memref<64xf64> -// CHECK-FAST: %[[VAL_20:.*]] = mulf %[[VAL_18]], %[[VAL_19]] : f64 -// CHECK-FAST: %[[VAL_21:.*]] = addf %[[VAL_16]], %[[VAL_20]] : f64 -// CHECK-FAST: scf.yield %[[VAL_21]] : f64 -// CHECK-FAST: } -// CHECK-FAST: store %[[VAL_22:.*]], %[[VAL_2]]{{\[}}%[[VAL_9]]] : memref<64xf64> -// CHECK-FAST: } -// CHECK-FAST: return %[[VAL_2]] : memref<64xf64> -// CHECK-FAST: } - -func @matvec(%arga: tensor<64x64xf64, #CSR>, %argb: tensor<64xf64>, %argx: tensor<64xf64>) -> tensor<64xf64> { +func @matvec(%arga: 
tensor<64x64xf64, #CSR>, + %argb: tensor<64xf64>, + %argx: tensor<64xf64>) -> tensor<64xf64> { %0 = linalg.generic #trait_matvec ins(%arga, %argb : tensor<64x64xf64, #CSR>, tensor<64xf64>) outs(%argx: tensor<64xf64>) { diff --git a/mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir b/mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir @@ -0,0 +1,125 @@ +// RUN: mlir-opt %s -sparsification | FileCheck %s --check-prefix=CHECK-HIR +// +// RUN: mlir-opt %s -sparsification --sparse-tensor-conversion \ +// RUN: --convert-linalg-to-loops | FileCheck %s --check-prefix=CHECK-MIR +// +// RUN: mlir-opt %s -sparsification --sparse-tensor-conversion \ +// RUN: --convert-linalg-to-loops --func-bufferize --tensor-constant-bufferize \ +// RUN: --tensor-bufferize --finalizing-bufferize | \ +// RUN: FileCheck %s --check-prefix=CHECK-LIR + +#CSR = #sparse_tensor.encoding<{dimLevelType = [ "dense", "compressed" ]}> + +#trait_matvec = { + indexing_maps = [ + affine_map<(i,j) -> (i,j)>, // A + affine_map<(i,j) -> (j)>, // b + affine_map<(i,j) -> (i)> // x (out) + ], + iterator_types = ["parallel","reduction"], + doc = "x(i) += A(i,j) * b(j)" +} + +// CHECK-HIR-LABEL: func @matvec( +// CHECK-HIR-SAME: %[[VAL_0:.*]]: tensor<64x64xf64, #sparse_tensor.encoding<{{.*}}>>, +// CHECK-HIR-SAME: %[[VAL_1:.*]]: tensor<64xf64>, +// CHECK-HIR-SAME: %[[VAL_2:.*]]: tensor<64xf64> {linalg.inplaceable = true}) -> tensor<64xf64> { +// CHECK-HIR: %[[VAL_3:.*]] = constant 64 : index +// CHECK-HIR: %[[VAL_4:.*]] = constant 0 : index +// CHECK-HIR: %[[VAL_5:.*]] = constant 1 : index +// CHECK-HIR: %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<64x64xf64, #sparse_tensor.encoding<{{.*}}>> to memref +// CHECK-HIR: %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<64x64xf64, #sparse_tensor.encoding<{{.*}}>> to memref +// CHECK-HIR: %[[VAL_8:.*]] = 
sparse_tensor.values %[[VAL_0]] : tensor<64x64xf64, #sparse_tensor.encoding<{{.*}}>> to memref +// CHECK-HIR: %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<64xf64> +// CHECK-HIR: %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<64xf64> +// CHECK-HIR: scf.for %[[VAL_11:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] { +// CHECK-HIR: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref +// CHECK-HIR: %[[VAL_13:.*]] = addi %[[VAL_11]], %[[VAL_5]] : index +// CHECK-HIR: %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref +// CHECK-HIR: %[[VAL_15:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<64xf64> +// CHECK-HIR: %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[VAL_12]] to %[[VAL_14]] step %[[VAL_5]] iter_args(%[[VAL_18:.*]] = %[[VAL_15]]) -> (f64) { +// CHECK-HIR: %[[VAL_19:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_17]]] : memref +// CHECK-HIR: %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref +// CHECK-HIR: %[[VAL_21:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_19]]] : memref<64xf64> +// CHECK-HIR: %[[VAL_22:.*]] = mulf %[[VAL_20]], %[[VAL_21]] : f64 +// CHECK-HIR: %[[VAL_23:.*]] = addf %[[VAL_18]], %[[VAL_22]] : f64 +// CHECK-HIR: scf.yield %[[VAL_23]] : f64 +// CHECK-HIR: } +// CHECK-HIR: memref.store %[[VAL_24:.*]], %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<64xf64> +// CHECK-HIR: } +// CHECK-HIR: %[[VAL_25:.*]] = memref.tensor_load %[[VAL_10]] : memref<64xf64> +// CHECK-HIR: return %[[VAL_25]] : tensor<64xf64> +// CHECK-HIR: } + +// CHECK-MIR-LABEL: func @matvec( +// CHECK-MIR-SAME: %[[VAL_0:.*]]: !llvm.ptr, +// CHECK-MIR-SAME: %[[VAL_1:.*]]: tensor<64xf64>, +// CHECK-MIR-SAME: %[[VAL_2:.*]]: tensor<64xf64> {linalg.inplaceable = true}) -> tensor<64xf64> { +// CHECK-MIR: %[[VAL_3:.*]] = constant 64 : index +// CHECK-MIR: %[[VAL_4:.*]] = constant 0 : index +// CHECK-MIR: %[[VAL_5:.*]] = constant 1 : index +// CHECK-MIR: %[[VAL_6:.*]] = call @sparsePointers(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> 
memref +// CHECK-MIR: %[[VAL_7:.*]] = call @sparseIndices(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref +// CHECK-MIR: %[[VAL_8:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr) -> memref +// CHECK-MIR: %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<64xf64> +// CHECK-MIR: %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<64xf64> +// CHECK-MIR: scf.for %[[VAL_11:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] { +// CHECK-MIR: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref +// CHECK-MIR: %[[VAL_13:.*]] = addi %[[VAL_11]], %[[VAL_5]] : index +// CHECK-MIR: %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref +// CHECK-MIR: %[[VAL_15:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<64xf64> +// CHECK-MIR: %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[VAL_12]] to %[[VAL_14]] step %[[VAL_5]] iter_args(%[[VAL_18:.*]] = %[[VAL_15]]) -> (f64) { +// CHECK-MIR: %[[VAL_19:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_17]]] : memref +// CHECK-MIR: %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref +// CHECK-MIR: %[[VAL_21:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_19]]] : memref<64xf64> +// CHECK-MIR: %[[VAL_22:.*]] = mulf %[[VAL_20]], %[[VAL_21]] : f64 +// CHECK-MIR: %[[VAL_23:.*]] = addf %[[VAL_18]], %[[VAL_22]] : f64 +// CHECK-MIR: scf.yield %[[VAL_23]] : f64 +// CHECK-MIR: } +// CHECK-MIR: memref.store %[[VAL_24:.*]], %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<64xf64> +// CHECK-MIR: } +// CHECK-MIR: %[[VAL_25:.*]] = memref.tensor_load %[[VAL_10]] : memref<64xf64> +// CHECK-MIR: return %[[VAL_25]] : tensor<64xf64> +// CHECK-MIR: } + +// CHECK-LIR-LABEL: func @matvec( +// CHECK-LIR-SAME: %[[VAL_0:.*]]: !llvm.ptr, +// CHECK-LIR-SAME: %[[VAL_1:.*]]: memref<64xf64>, +// CHECK-LIR-SAME: %[[VAL_2:.*]]: memref<64xf64> {linalg.inplaceable = true}) -> memref<64xf64> { +// CHECK-LIR: %[[VAL_3:.*]] = constant 64 : index +// CHECK-LIR: %[[VAL_4:.*]] = constant 0 : index +// CHECK-LIR: %[[VAL_5:.*]] = constant 1 
: index +// CHECK-LIR: %[[VAL_6:.*]] = call @sparsePointers(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref +// CHECK-LIR: %[[VAL_7:.*]] = call @sparseIndices(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref +// CHECK-LIR: %[[VAL_8:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr) -> memref +// CHECK-LIR: scf.for %[[VAL_9:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] { +// CHECK-LIR: %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_9]]] : memref +// CHECK-LIR: %[[VAL_11:.*]] = addi %[[VAL_9]], %[[VAL_5]] : index +// CHECK-LIR: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref +// CHECK-LIR: %[[VAL_13:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] : memref<64xf64> +// CHECK-LIR: %[[VAL_14:.*]] = scf.for %[[VAL_15:.*]] = %[[VAL_10]] to %[[VAL_12]] step %[[VAL_5]] iter_args(%[[VAL_16:.*]] = %[[VAL_13]]) -> (f64) { +// CHECK-LIR: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref +// CHECK-LIR: %[[VAL_18:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_15]]] : memref +// CHECK-LIR: %[[VAL_19:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_17]]] : memref<64xf64> +// CHECK-LIR: %[[VAL_20:.*]] = mulf %[[VAL_18]], %[[VAL_19]] : f64 +// CHECK-LIR: %[[VAL_21:.*]] = addf %[[VAL_16]], %[[VAL_20]] : f64 +// CHECK-LIR: scf.yield %[[VAL_21]] : f64 +// CHECK-LIR: } +// CHECK-LIR: memref.store %[[VAL_22:.*]], %[[VAL_2]]{{\[}}%[[VAL_9]]] : memref<64xf64> +// CHECK-LIR: } +// CHECK-LIR: return %[[VAL_2]] : memref<64xf64> +// CHECK-LIR: } + +func @matvec(%arga: tensor<64x64xf64, #CSR>, + %argb: tensor<64xf64>, + %argx: tensor<64xf64> {linalg.inplaceable = true}) -> tensor<64xf64> { + %0 = linalg.generic #trait_matvec + ins(%arga, %argb : tensor<64x64xf64, #CSR>, tensor<64xf64>) + outs(%argx: tensor<64xf64>) { + ^bb(%A: f64, %b: f64, %x: f64): + %0 = mulf %A, %b : f64 + %1 = addf %x, %0 : f64 + linalg.yield %1 : f64 + } -> tensor<64xf64> + return %0 : tensor<64xf64> +} diff --git 
a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_matmul.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_matmul.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_matmul.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_matmul.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s \ -// RUN: --sparsification="fast-output" --sparse-tensor-conversion \ +// RUN: --sparsification --sparse-tensor-conversion \ // RUN: --convert-linalg-to-loops --convert-vector-to-scf --convert-scf-to-std \ // RUN: --func-bufferize --tensor-constant-bufferize --tensor-bufferize \ // RUN: --std-bufferize --finalizing-bufferize \ @@ -41,7 +41,7 @@ func @sampled_dense_dense(%args: tensor, %arga: tensor, %argb: tensor, - %argx: tensor) -> tensor { + %argx: tensor {linalg.inplaceable = true}) -> tensor { %0 = linalg.generic #trait_sampled_dense_dense ins(%args, %arga, %argb: tensor, tensor, tensor) outs(%argx: tensor) {