diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -16,7 +16,9 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/ExecutionEngine/SparseTensor/Enums.h"
+#include "mlir/ExecutionEngine/SparseTensorUtils.h"
 #include "mlir/IR/Builders.h"
 
 namespace mlir {
@@ -193,6 +195,11 @@
       static_cast<uint8_t>(dimLevelTypeEncoding(dlt)));
 }
 
+void translateIndicesArray(OpBuilder &builder, Location loc,
+                           ArrayRef<ReassociationIndices> reassociation,
+                           ValueRange srcIndices, ArrayRef<Value> srcShape,
+                           ArrayRef<Value> dstShape,
+                           SmallVectorImpl<Value> &dstIndices);
+
 } // namespace sparse_tensor
 } // namespace mlir
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -199,3 +199,53 @@
     return builder.create<complex::NotEqualOp>(loc, v, zero);
   llvm_unreachable("Non-numeric type");
 }
+
+/// Helper method to translate indices during a reshaping operation.
+void mlir::sparse_tensor::translateIndicesArray(
+    OpBuilder &builder, Location loc,
+    ArrayRef<ReassociationIndices> reassociation, ValueRange srcIndices,
+    ArrayRef<Value> srcShape, ArrayRef<Value> dstShape,
+    SmallVectorImpl<Value> &dstIndices) {
+  unsigned i = 0;
+  unsigned start = 0;
+  unsigned dstRank = dstShape.size();
+  unsigned srcRank = srcShape.size();
+  assert(srcRank == srcIndices.size());
+  bool isCollapse = srcRank > dstRank;
+  ArrayRef<Value> shape = isCollapse ? srcShape : dstShape;
+  // Iterate over reassociation map.
+  for (const auto &map : llvm::enumerate(reassociation)) {
+    // Prepare strides information in dimension slice.
+    Value linear = constantIndex(builder, loc, 1);
+    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
+      linear = builder.create<arith::MulIOp>(loc, linear, shape[j]);
+    }
+    // Start expansion.
+    Value val;
+    if (!isCollapse)
+      val = srcIndices[i];
+    // Iterate over dimension slice.
+    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
+      linear = builder.create<arith::DivUIOp>(loc, linear, shape[j]);
+      if (isCollapse) {
+        Value old = srcIndices[j];
+        Value mul = builder.create<arith::MulIOp>(loc, old, linear);
+        val = val ? builder.create<arith::AddIOp>(loc, val, mul) : mul;
+      } else {
+        Value old = val;
+        val = builder.create<arith::DivUIOp>(loc, val, linear);
+        assert(dstIndices.size() == j);
+        dstIndices.push_back(val);
+        val = builder.create<arith::RemUIOp>(loc, old, linear);
+      }
+    }
+    // Finalize collapse.
+    if (isCollapse) {
+      assert(dstIndices.size() == i);
+      dstIndices.push_back(val);
+    }
+    start += map.value().size();
+    i++;
+  }
+  assert(dstIndices.size() == dstRank);
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -475,44 +475,21 @@
                              ArrayRef<Value> srcShape) {
   unsigned dstRank = dstTp.getRank();
   unsigned srcRank = srcTp.getRank();
-  unsigned start = 0;
-  unsigned i = 0;
-  bool isExpand = srcRank > dstRank;
-  ArrayRef<Value> shape = isExpand ? srcShape : dstShape;
-  // Iterate over reassociation map.
-  for (const auto &map : llvm::enumerate(reassociation)) {
-    // Prepare strides information in dimension slice.
-    Value linear = constantIndex(rewriter, loc, 1);
-    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
-      linear = rewriter.create<arith::MulIOp>(loc, linear, shape[j]);
-    }
-    // Start collapse.
-    Value idx = constantIndex(rewriter, loc, i++);
-    Value val;
-    if (!isExpand)
-      val = rewriter.create<memref::LoadOp>(loc, srcIdx, idx);
-    // Iterate over dimension slice.
-    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
-      linear = rewriter.create<arith::DivUIOp>(loc, linear, shape[j]);
-      Value jdx = constantIndex(rewriter, loc, j);
-      if (isExpand) {
-        Value old = rewriter.create<memref::LoadOp>(loc, srcIdx, jdx);
-        Value mul = rewriter.create<arith::MulIOp>(loc, old, linear);
-        val = val ? rewriter.create<arith::AddIOp>(loc, val, mul) : mul;
-      } else {
-        Value old = val;
-        val = rewriter.create<arith::DivUIOp>(loc, val, linear);
-        rewriter.create<memref::StoreOp>(loc, val, dstIdx, jdx);
-        val = rewriter.create<arith::RemUIOp>(loc, old, linear);
-      }
-    }
-    // Finalize expansion.
-    if (isExpand)
-      rewriter.create<memref::StoreOp>(loc, val, dstIdx, idx);
-    start += map.value().size();
-  }
-  // Sanity.
-  assert((isExpand && i == dstRank) || (!isExpand && i == srcRank));
+
+  SmallVector<Value> srcIndices;
+  for (unsigned i = 0; i < srcRank; i++) {
+    Value idx = rewriter.create<memref::LoadOp>(
+        loc, srcIdx, constantIndex(rewriter, loc, i));
+    srcIndices.push_back(idx);
+  }
+
+  SmallVector<Value> dstIndices;
+  translateIndicesArray(rewriter, loc, reassociation, srcIndices, srcShape,
+                        dstShape, dstIndices);
+
+  for (unsigned i = 0; i < dstRank; i++)
+    rewriter.create<memref::StoreOp>(loc, dstIndices[i], dstIdx,
+                                     constantIndex(rewriter, loc, i));
 }
 
 /// Helper method to compute the shape of destination tensor of a reshape
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
@@ -26,8 +26,8 @@
 // CHECK-CONV: } do {
 // CHECK-CONV: %[[X:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<1xindex>
 // CHECK-CONV: %[[D:.*]] = arith.divui %[[X]], %[[C10]] : index
-// CHECK-CONV: memref.store %[[D]], %{{.*}}[%[[C0]]] : memref<2xindex>
 // CHECK-CONV: %[[R:.*]] = arith.remui %[[X]], %[[C10]] : index
+// CHECK-CONV: memref.store %[[D]], %{{.*}}[%[[C0]]] : memref<2xindex>
 // CHECK-CONV: memref.store %[[R]], %{{.*}}[%[[C1]]] : memref<2xindex>
 // CHECK-CONV: call @addEltF64
 // CHECK-CONV: scf.yield
@@ -64,8 +64,8 @@
 // CHECK-CONV: scf.condition
 // CHECK-CONV: } do {
 // CHECK-CONV: %[[X:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<2xindex>
-// CHECK-CONV: %[[M:.*]] = arith.muli %[[X]], %[[C10]] : index
 // CHECK-CONV: %[[Y:.*]] = memref.load %{{.*}}[%[[C1]]] : memref<2xindex>
+// CHECK-CONV: %[[M:.*]] = arith.muli %[[X]], %[[C10]] : index
 // CHECK-CONV: %[[A:.*]] = arith.addi %[[M]], %[[Y]] : index
 // CHECK-CONV: memref.store %[[A]], %{{.*}}[%[[C0]]] : memref<1xindex>
 // CHECK-CONV: call @addEltF64
@@ -103,14 +103,14 @@
 // CHECK-CONV: call @getNextF64
 // CHECK-CONV: scf.condition
 // CHECK-CONV: } do {
-// CHECK-CONV: %[[M:.*]] = arith.muli %[[D1]], %[[C10]] : index
 // CHECK-CONV: %[[L:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<1xindex>
+// CHECK-CONV: %[[M:.*]] = arith.muli %[[D1]], %[[C10]] : index
 // CHECK-CONV: %[[D2:.*]] = arith.divui %[[M]], %[[D1]] : index
 // CHECK-CONV: %[[D3:.*]] = arith.divui %[[L]], %[[D2]] : index
-// CHECK-CONV: memref.store %[[D3]], %{{.*}}[%[[C0]]] : memref<2xindex>
 // CHECK-CONV: %[[R:.*]] = arith.remui %[[L]], %[[D2]] : index
 // CHECK-CONV: %[[D4:.*]] = arith.divui %[[D2]], %[[C10]] : index
 // CHECK-CONV: %[[D5:.*]] = arith.divui %[[R]], %[[D4]] : index
+// CHECK-CONV: memref.store %[[D3]], %{{.*}}[%[[C0]]] : memref<2xindex>
 // CHECK-CONV: memref.store %[[D5]], %{{.*}}[%[[C1]]] : memref<2xindex>
 // CHECK-CONV: call @addEltF64
 // CHECK-CONV: scf.yield
@@ -147,11 +147,11 @@
 // CHECK-CONV: call @getNextF64
 // CHECK-CONV: scf.condition
 // CHECK-CONV: } do {
-// CHECK-CONV: %[[D1:.*]] = arith.divui %[[M1]], %[[C10]] : index
 // CHECK-CONV: %[[X:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<2xindex>
+// CHECK-CONV: %[[Y:.*]] = memref.load %{{.*}}[%[[C1]]] : memref<2xindex>
+// CHECK-CONV: %[[D1:.*]] = arith.divui %[[M1]], %[[C10]] : index
 // CHECK-CONV: %[[M2:.*]] = arith.muli %[[X]], %[[D1]] : index
 // CHECK-CONV: %[[D2:.*]] = arith.divui %[[D1]], %{{.*}} : index
-// CHECK-CONV: %[[Y:.*]] = memref.load %{{.*}}[%[[C1]]] : memref<2xindex>
 // CHECK-CONV: %[[M3:.*]] = arith.muli %[[Y]], %[[D2]] : index
 // CHECK-CONV: %[[A:.*]] = arith.addi %[[M2]], %[[M3]] : index
 // CHECK-CONV: memref.store %[[A]], %{{.*}}[%[[C0]]] : memref<1xindex>
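
Note on the index arithmetic (a reviewer sketch, not part of the patch): for each reassociation group, translateIndicesArray first multiplies up the sizes of the group's dimensions, then divides that product down one dimension at a time, so that `linear` always holds the stride of dimension j within its group. A collapse then linearizes (val += srcIndices[j] * stride, via muli/addi), while an expansion delinearizes with quotient and remainder (divui/remui). The self-contained C++ below mirrors that arithmetic on plain int64_t instead of MLIR Values; the names translateIndices, Indices, and Reassociation are illustrative only and not part of the MLIR API.

#include <cassert>
#include <cstdint>
#include <vector>

using Indices = std::vector<int64_t>;
using Reassociation = std::vector<std::vector<unsigned>>;

Indices translateIndices(const Reassociation &reassociation,
                         const Indices &srcIndices, const Indices &srcShape,
                         const Indices &dstShape) {
  bool isCollapse = srcShape.size() > dstShape.size();
  // Strides come from the higher-rank side, as in the patch.
  const Indices &shape = isCollapse ? srcShape : dstShape;
  Indices dstIndices;
  unsigned i = 0, start = 0;
  for (const auto &group : reassociation) {
    // Product of the dimension sizes in this reassociation group.
    int64_t linear = 1;
    for (unsigned j = start, end = start + group.size(); j < end; j++)
      linear *= shape[j];
    int64_t val = isCollapse ? 0 : srcIndices[i];
    for (unsigned j = start, end = start + group.size(); j < end; j++) {
      linear /= shape[j]; // Now the stride of dimension j within the group.
      if (isCollapse) {
        val += srcIndices[j] * linear; // Linearize (muli/addi in the IR).
      } else {
        dstIndices.push_back(val / linear); // Delinearize: quotient (divui) ...
        val %= linear;                      // ... and remainder (remui).
      }
    }
    if (isCollapse)
      dstIndices.push_back(val);
    start += group.size();
    i++;
  }
  assert(dstIndices.size() == dstShape.size());
  return dstIndices;
}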
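
And a hypothetical driver for the sketch, using a static 100 <-> 10x10 reshape of the kind the first two CHECK blocks appear to exercise (the later hunks test the same stride computation with dynamic sizes, done via runtime arith ops instead of constant folding):

#include <cstdio>

int main() {
  // A single reassociation group folding both dimensions together.
  Reassociation reassociation = {{0, 1}};
  // Expansion: linear index 25 in a 100-element vector -> (2, 5) in 10x10.
  Indices expanded = translateIndices(reassociation, {25}, {100}, {10, 10});
  std::printf("expand:   25 -> (%lld, %lld)\n", (long long)expanded[0],
              (long long)expanded[1]);
  // Collapse: (2, 5) in 10x10 -> linear index 2 * 10 + 5 = 25.
  Indices collapsed = translateIndices(reassociation, {2, 5}, {10, 10}, {100});
  std::printf("collapse: (2, 5) -> %lld\n", (long long)collapsed[0]);
  return 0;
}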