diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -282,14 +282,18 @@
       codegen.indices[tensor][idx] =
           rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim);
     }
-    // Find lower and upper bound in current dimension.
+    // Find lower and upper bound in current dimension. Note that a
+    // permuted encoding queries static type dimensions accordingly,
+    // but queries dynamic type dimensions in the generated order.
     Value up;
-    if (shape[d] == MemRefType::kDynamicSize) {
+    unsigned p = perm(enc, d);
+    if (shape[p] == MemRefType::kDynamicSize) {
       up = rewriter.create<tensor::DimOp>(loc, t->get(), d);
       args.push_back(up);
     } else {
-      up = rewriter.create<ConstantIndexOp>(loc, shape[d]);
+      up = rewriter.create<ConstantIndexOp>(loc, shape[p]);
     }
+    assert(codegen.highs[tensor][idx] == nullptr);
     codegen.sizes[idx] = codegen.highs[tensor][idx] = up;
   }
   // Perform the required bufferization. Dense inputs materialize
diff --git a/mlir/test/Dialect/SparseTensor/sparse_perm.mlir b/mlir/test/Dialect/SparseTensor/sparse_perm.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_perm.mlir
@@ -0,0 +1,92 @@
+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+// RUN: mlir-opt %s -sparsification | FileCheck %s
+
+#X = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "dense", "dense" ],
+  dimOrdering = affine_map<(i,j,k) -> (k,i,j)>
+}>
+
+#trait = {
+  indexing_maps = [
+    affine_map<(i,j,k) -> (k,i,j)>, // A (in)
+    affine_map<(i,j,k) -> (i,j,k)>  // X (out)
+  ],
+  iterator_types = ["parallel", "parallel", "parallel"]
+}
+
+// CHECK-LABEL: builtin.func @sparse_static_dims(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<10x20x30xf32, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<20x30x10xf32>) -> tensor<20x30x10xf32> {
+// CHECK: %[[VAL_2:.*]] = constant 20 : index
+// CHECK: %[[VAL_3:.*]] = constant 30 : index
+// CHECK: %[[VAL_4:.*]] = constant 10 : index
+// CHECK: %[[VAL_5:.*]] = constant 0 : index
+// CHECK: %[[VAL_6:.*]] = constant 1 : index
+// CHECK: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10x20x30xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_1]] : memref<20x30x10xf32>
+// CHECK: %[[VAL_9:.*]] = memref.alloc() : memref<20x30x10xf32>
+// CHECK: memref.copy %[[VAL_8]], %[[VAL_9]] : memref<20x30x10xf32> to memref<20x30x10xf32>
+// CHECK: scf.for %[[VAL_10:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {
+// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] {
+// CHECK: %[[VAL_12:.*]] = muli %[[VAL_10]], %[[VAL_4]] : index
+// CHECK: %[[VAL_13:.*]] = addi %[[VAL_12]], %[[VAL_11]] : index
+// CHECK: scf.for %[[VAL_14:.*]] = %[[VAL_5]] to %[[VAL_2]] step %[[VAL_6]] {
+// CHECK: %[[VAL_15:.*]] = muli %[[VAL_13]], %[[VAL_2]] : index
+// CHECK: %[[VAL_16:.*]] = addi %[[VAL_15]], %[[VAL_14]] : index
+// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_16]]] : memref<?xf32>
+// CHECK: memref.store %[[VAL_17]], %[[VAL_9]]{{\[}}%[[VAL_14]], %[[VAL_10]], %[[VAL_11]]] : memref<20x30x10xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: %[[VAL_18:.*]] = memref.tensor_load %[[VAL_9]] : memref<20x30x10xf32>
+// CHECK: return %[[VAL_18]] : tensor<20x30x10xf32>
+// CHECK: }
+func @sparse_static_dims(%arga: tensor<10x20x30xf32, #X>,
+                         %argx: tensor<20x30x10xf32>) -> tensor<20x30x10xf32> {
+  %0 = linalg.generic #trait
+    ins(%arga: tensor<10x20x30xf32, #X>)
+    outs(%argx: tensor<20x30x10xf32>) {
+      ^bb(%a : f32, %x: f32):
+        linalg.yield %a : f32
+  } -> tensor<20x30x10xf32>
+  return %0 : tensor<20x30x10xf32>
+}
+
+// CHECK-LABEL: builtin.func @sparse_dynamic_dims(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+// CHECK: %[[VAL_2:.*]] = constant 2 : index
+// CHECK: %[[VAL_3:.*]] = constant 0 : index
+// CHECK: %[[VAL_4:.*]] = constant 1 : index
+// CHECK: %[[VAL_5:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_6:.*]] = tensor.dim %[[VAL_1]], %[[VAL_3]] : tensor<?x?x?xf32>
+// CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor<?x?x?xf32>
+// CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[VAL_2]] : tensor<?x?x?xf32>
+// CHECK: %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<?x?x?xf32>
+// CHECK: %[[VAL_10:.*]] = memref.alloc(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) : memref<?x?x?xf32>
+// CHECK: memref.copy %[[VAL_9]], %[[VAL_10]] : memref<?x?x?xf32> to memref<?x?x?xf32>
+// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_3]] to %[[VAL_7]] step %[[VAL_4]] {
+// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_3]] to %[[VAL_8]] step %[[VAL_4]] {
+// CHECK: %[[VAL_13:.*]] = muli %[[VAL_8]], %[[VAL_11]] : index
+// CHECK: %[[VAL_14:.*]] = addi %[[VAL_13]], %[[VAL_12]] : index
+// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_4]] {
+// CHECK: %[[VAL_16:.*]] = muli %[[VAL_6]], %[[VAL_14]] : index
+// CHECK: %[[VAL_17:.*]] = addi %[[VAL_16]], %[[VAL_15]] : index
+// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_17]]] : memref<?xf32>
+// CHECK: memref.store %[[VAL_18]], %[[VAL_10]]{{\[}}%[[VAL_15]], %[[VAL_11]], %[[VAL_12]]] : memref<?x?x?xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: %[[VAL_19:.*]] = memref.tensor_load %[[VAL_10]] : memref<?x?x?xf32>
+// CHECK: return %[[VAL_19]] : tensor<?x?x?xf32>
+// CHECK: }
+func @sparse_dynamic_dims(%arga: tensor<?x?x?xf32, #X>,
+                          %argx: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.generic #trait
+    ins(%arga: tensor<?x?x?xf32, #X>)
+    outs(%argx: tensor<?x?x?xf32>) {
+      ^bb(%a : f32, %x: f32):
+        linalg.yield %a : f32
+  } -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
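
Note on the change: the patch reads a static size as shape[perm(enc, d)] instead of shape[d], while a dynamic size is still queried with a dim operation at the generated position d. The following is a minimal standalone C++ sketch of that lookup, not the pass code itself: the helper name `permute`, the `kDynamicSize` sentinel value, and the plain std::vector types are illustrative stand-ins for `perm`, `MemRefType::kDynamicSize`, and the MLIR types, using the tensor<10x20x30xf32, #X> example with dimOrdering (i,j,k) -> (k,i,j).

// Toy model of the permuted size lookup; all names are stand-ins.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

constexpr int64_t kDynamicSize = -1; // stand-in for MemRefType::kDynamicSize

// dimOrdering (i,j,k) -> (k,i,j) as a permutation: generated (stored)
// dimension d corresponds to type dimension order[d].
static unsigned permute(const std::vector<unsigned> &order, unsigned d) {
  assert(d < order.size());
  return order[d];
}

int main() {
  std::vector<int64_t> shape = {10, 20, 30}; // tensor<10x20x30xf32, #X>
  std::vector<unsigned> order = {2, 0, 1};   // (i,j,k) -> (k,i,j)
  for (unsigned d = 0; d < shape.size(); d++) {
    unsigned p = permute(order, d);
    if (shape[p] == kDynamicSize)
      std::cout << "dim " << d << ": dynamic, query position " << d << " at runtime\n";
    else
      std::cout << "dim " << d << ": static size " << shape[p] << "\n";
  }
  // Prints: dim 0: static size 30, dim 1: static size 10, dim 2: static size 20.
  return 0;
}

In the real pass the dynamic branch emits a dim query at the generated position d rather than consulting the permuted static shape, which is the case the new @sparse_dynamic_dims test exercises alongside the static @sparse_static_dims case.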