diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -232,12 +232,33 @@ LogicalResult matchAndRewrite(tensor::DimOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - if (!operands[0].getType().isa()) - return failure(); Type resType = op.getType(); + auto enc = getSparseTensorEncoding(op.source().getType()); + if (!enc) + return failure(); + // Permute the dim index. + Optional index = op.getConstantIndex(); + if (!index.hasValue()) + return failure(); + int64_t idx = index.getValue(); + AffineMap p = enc.getDimOrdering(); + if (p) { + assert(p.isPermutation()); + for (unsigned i = 0, sz = p.getNumResults(); i < sz; i++) { + if (p.getDimPosition(i) == idx) { + idx = i; + break; + } + } + } + // Generate the call. StringRef name = "sparseDimSize"; + SmallVector params; + params.push_back(operands[0]); + params.push_back( + rewriter.create(op.getLoc(), rewriter.getIndexAttr(idx))); rewriter.replaceOpWithNewOp( - op, resType, getFunc(op, name, resType, operands), operands); + op, resType, getFunc(op, name, resType, params), params); return success(); } }; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -282,17 +282,11 @@ codegen.indices[tensor][idx] = rewriter.create(loc, indTp, t->get(), dim); } - // Find lower and upper bound in current dimension. Note that a - // permuted encoding queries static type dimensions accordingly, - // but queries dynamic type dimensions in the generated order. - Value up; + // Find upper bound in current dimension. unsigned p = perm(enc, d); - if (shape[p] == MemRefType::kDynamicSize) { - up = rewriter.create(loc, t->get(), d); + Value up = linalg::createOrFoldDimOp(rewriter, loc, t->get(), p); + if (shape[p] == MemRefType::kDynamicSize) args.push_back(up); - } else { - up = rewriter.create(loc, shape[p]); - } assert(codegen.highs[tensor][idx] == nullptr); codegen.sizes[idx] = codegen.highs[tensor][idx] = up; } diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir --- a/mlir/test/Dialect/SparseTensor/conversion.mlir +++ b/mlir/test/Dialect/SparseTensor/conversion.mlir @@ -29,17 +29,29 @@ dimOrdering = affine_map<(i,j,k) -> (k,i,j)> }> -// CHECK-LABEL: func @sparse_dim( +// CHECK-LABEL: func @sparse_dim1d( // CHECK-SAME: %[[A:.*]]: !llvm.ptr) // CHECK: %[[C:.*]] = constant 0 : index // CHECK: %[[D:.*]] = call @sparseDimSize(%[[A]], %[[C]]) // CHECK: return %[[D]] : index -func @sparse_dim(%arg0: tensor) -> index { +func @sparse_dim1d(%arg0: tensor) -> index { %c = constant 0 : index %0 = tensor.dim %arg0, %c : tensor return %0 : index } +// CHECK-LABEL: func @sparse_dim3d( +// CHECK-SAME: %[[A:.*]]: !llvm.ptr) +// CHECK: %[[C:.*]] = constant 2 : index +// CHECK: %[[D:.*]] = call @sparseDimSize(%[[A]], %[[C]]) +// CHECK: return %[[D]] : index +func @sparse_dim3d(%arg0: tensor) -> index { + // Needs permuting 1 into 2. + %c = constant 1 : index + %0 = tensor.dim %arg0, %c : tensor + return %0 : index +} + // CHECK-LABEL: func @sparse_new1d( // CHECK-SAME: %[[A:.*]]: !llvm.ptr) -> !llvm.ptr // CHECK-DAG: %[[U:.*]] = constant dense<1> : tensor<1xi8> diff --git a/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir b/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir @@ -0,0 +1,92 @@ +// RUN: mlir-opt %s -sparsification --canonicalize | FileCheck %s --check-prefix=CHECK-HIR +// +// RUN: mlir-opt %s -sparsification --sparse-tensor-conversion --canonicalize | \ +// RUN: FileCheck %s --check-prefix=CHECK-MIR + +#X = #sparse_tensor.encoding<{ + dimLevelType = [ "dense", "dense", "dense" ], + dimOrdering = affine_map<(i,j,k) -> (k,i,j)> +}> + +#trait = { + indexing_maps = [ + affine_map<(i,j,k) -> (k,i,j)>, // A (in) + affine_map<(i,j,k) -> ()> // X (out) + ], + iterator_types = ["reduction", "reduction", "reduction"] +} + +// CHECK-HIR-LABEL: builtin.func @sparse_dynamic_dims( +// CHECK-HIR-SAME: %[[VAL_0:.*]]: tensor>, +// CHECK-HIR-SAME: %[[VAL_1:.*]]: tensor) -> tensor { +// CHECK-HIR-DAG: %[[C0:.*]] = constant 0 : index +// CHECK-HIR-DAG: %[[C1:.*]] = constant 1 : index +// CHECK-HIR-DAG: %[[C2:.*]] = constant 2 : index +// CHECK-HIR: %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[C2]] : tensor> +// CHECK-HIR: %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[C0]] : tensor> +// CHECK-HIR: %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[C1]] : tensor> +// CHECK-HIR: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor> +// CHECK-HIR: %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref +// CHECK-HIR: %[[VAL_10:.*]] = memref.alloc() : memref +// CHECK-HIR: memref.copy %[[VAL_9]], %[[VAL_10]] : memref to memref +// CHECK-HIR: scf.for %[[VAL_11:.*]] = %[[C0]] to %[[VAL_5]] step %[[C1]] { +// CHECK-HIR: scf.for %[[VAL_12:.*]] = %[[C0]] to %[[VAL_6]] step %[[C1]] { +// CHECK-HIR: %[[VAL_13:.*]] = muli %[[VAL_6]], %[[VAL_11]] : index +// CHECK-HIR: %[[VAL_14:.*]] = addi %[[VAL_13]], %[[VAL_12]] : index +// CHECK-HIR: %[[VAL_15:.*]] = memref.load %[[VAL_10]][] : memref +// CHECK-HIR: %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[C0]] to %[[VAL_7]] step %[[C1]] iter_args(%[[VAL_18:.*]] = %[[VAL_15]]) -> (f32) { +// CHECK-HIR: %[[VAL_19:.*]] = muli %[[VAL_7]], %[[VAL_14]] : index +// CHECK-HIR: %[[VAL_20:.*]] = addi %[[VAL_19]], %[[VAL_17]] : index +// CHECK-HIR: %[[VAL_21:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_20]]] : memref +// CHECK-HIR: %[[VAL_22:.*]] = addf %[[VAL_18]], %[[VAL_21]] : f32 +// CHECK-HIR: scf.yield %[[VAL_22]] : f32 +// CHECK-HIR: } +// CHECK-HIR: memref.store %[[VAL_23:.*]], %[[VAL_10]][] : memref +// CHECK-HIR: } +// CHECK-HIR: } +// CHECK-HIR: %[[VAL_24:.*]] = memref.tensor_load %[[VAL_10]] : memref +// CHECK-HIR: return %[[VAL_24]] : tensor +// CHECK-HIR: } +// +// CHECK-MIR-LABEL: builtin.func @sparse_dynamic_dims( +// CHECK-MIR-SAME: %[[VAL_0:.*]]: !llvm.ptr, +// CHECK-MIR-SAME: %[[VAL_1:.*]]: tensor) -> tensor { +// CHECK-MIR-DAG: %[[C0:.*]] = constant 0 : index +// CHECK-MIR-DAG: %[[C1:.*]] = constant 1 : index +// CHECK-MIR-DAG: %[[C2:.*]] = constant 2 : index +// CHECK-MIR: %[[VAL_5:.*]] = call @sparseDimSize(%[[VAL_0]], %[[C0]]) : (!llvm.ptr, index) -> index +// CHECK-MIR: %[[VAL_6:.*]] = call @sparseDimSize(%[[VAL_0]], %[[C1]]) : (!llvm.ptr, index) -> index +// CHECK-MIR: %[[VAL_7:.*]] = call @sparseDimSize(%[[VAL_0]], %[[C2]]) : (!llvm.ptr, index) -> index +// CHECK-MIR: %[[VAL_8:.*]] = call @sparseValuesF32(%[[VAL_0]]) : (!llvm.ptr) -> memref +// CHECK-MIR: %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref +// CHECK-MIR: %[[VAL_10:.*]] = memref.alloc() : memref +// CHECK-MIR: memref.copy %[[VAL_9]], %[[VAL_10]] : memref to memref +// CHECK-MIR: scf.for %[[VAL_11:.*]] = %[[C0]] to %[[VAL_5]] step %[[C1]] { +// CHECK-MIR: scf.for %[[VAL_12:.*]] = %[[C0]] to %[[VAL_6]] step %[[C1]] { +// CHECK-MIR: %[[VAL_13:.*]] = muli %[[VAL_6]], %[[VAL_11]] : index +// CHECK-MIR: %[[VAL_14:.*]] = addi %[[VAL_13]], %[[VAL_12]] : index +// CHECK-MIR: %[[VAL_15:.*]] = memref.load %[[VAL_10]][] : memref +// CHECK-MIR: %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[C0]] to %[[VAL_7]] step %[[C1]] iter_args(%[[VAL_18:.*]] = %[[VAL_15]]) -> (f32) { +// CHECK-MIR: %[[VAL_19:.*]] = muli %[[VAL_7]], %[[VAL_14]] : index +// CHECK-MIR: %[[VAL_20:.*]] = addi %[[VAL_19]], %[[VAL_17]] : index +// CHECK-MIR: %[[VAL_21:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_20]]] : memref +// CHECK-MIR: %[[VAL_22:.*]] = addf %[[VAL_18]], %[[VAL_21]] : f32 +// CHECK-MIR: scf.yield %[[VAL_22]] : f32 +// CHECK-MIR: } +// CHECK-MIR: memref.store %[[VAL_23:.*]], %[[VAL_10]][] : memref +// CHECK-MIR: } +// CHECK-MIR: } +// CHECK-MIR: %[[VAL_24:.*]] = memref.tensor_load %[[VAL_10]] : memref +// CHECK-MIR: return %[[VAL_24]] : tensor +// CHECK-MIR: } +func @sparse_dynamic_dims(%arga: tensor, + %argx: tensor) -> tensor { + %0 = linalg.generic #trait + ins(%arga: tensor) + outs(%argx: tensor) { + ^bb(%a : f32, %x: f32): + %0 = addf %x, %a : f32 + linalg.yield %0 : f32 + } -> tensor + return %0 : tensor +}