diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td @@ -169,6 +169,7 @@ let dependentDialects = [ "arith::ArithmeticDialect", "bufferization::BufferizationDialect", + "linalg::LinalgDialect", "memref::MemRefDialect", "scf::SCFDialect", "sparse_tensor::SparseTensorDialect", diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -19,6 +19,7 @@ #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" @@ -447,6 +448,53 @@ } }; + +/// Sparse codegen rule for the expand op: replaces the op with temporary values, filled-switch, and indices buffers, followed by a zero initial index. +class SparseExpandConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(ExpandOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + ShapedType srcType = op.getTensor().getType().cast(); + Type eltType = srcType.getElementType(); + Type boolType = rewriter.getIntegerType(1); + Type idxType = rewriter.getIndexType(); + // All initialization should be done on entry of the loop nest, so new ops are inserted right after the op defining the tensor. NOTE(review): presumably the tensor is never a block argument here, since getDefiningOp() would then yield null; confirm. + rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp()); + Value src = op.getTensor(); + Value sz = rewriter.create(loc, src, srcType.getRank() - 1); + // Generate a one-dimensional, dynamically sized memref for `sz` elements of type `t`. 
+ auto genAlloc = [&](Type t) { + auto memTp = MemRefType::get({ShapedType::kDynamicSize}, t); + return rewriter.create(loc, memTp, ValueRange{sz}); + }; + // Allocate temporary buffers for values, filled-switch, and indices. + // We do not use stack buffers for this, since the expanded size may + // be rather large (as it envelops a single expanded dense dimension). + Value values = genAlloc(eltType); + Value filled = genAlloc(boolType); + Value indices = genAlloc(idxType); + Value zero = constantZero(rewriter, loc, idxType); + // Reset the values/filled-switch to all-zero/false. Note that this + // introduces an O(N) operation into the computation, but this reset + // operation is amortized over the innermost loops for the access + // pattern expansion. As noted in the operation doc, we would like + // to amortize this setup cost even between kernels. + // The indices buffer is not reset here. + rewriter.create( + loc, ValueRange{constantZero(rewriter, loc, eltType)}, + ValueRange{values}); + rewriter.create( + loc, ValueRange{constantZero(rewriter, loc, boolType)}, + ValueRange{filled}); + // Replace the expansion op with these three buffers and the zero initial index. 
+ assert(op.getNumResults() == 4); + rewriter.replaceOp(op, {values, filled, indices, zero}); + return success(); + } +}; + } // namespace //===----------------------------------------------------------------------===// @@ -467,8 +515,8 @@ void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns) { patterns.add( - typeConverter, patterns.getContext()); + SparseExpandConverter, SparseTensorAllocConverter, + SparseTensorDeallocConverter, SparseToPointersConverter, + SparseToIndicesConverter, SparseToValuesConverter, + SparseTensorLoadConverter>(typeConverter, patterns.getContext()); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -177,6 +177,10 @@ [&](bufferization::DeallocTensorOp op) { return converter.isLegal(op.getTensor().getType()); }); + // The following operations and dialects may be introduced by the + // rewriting rules, and are therefore marked as legal. + target.addLegalOp(); + // Legal dialects may occur in generated code. 
target.addLegalDialect return %1 : tensor<10x20x30xf64, #Dense3D> } + +// CHECK-LABEL: func.func @sparse_expansion() +// CHECK: %[[A:.*]] = memref.alloc() : memref<8xf64> +// CHECK: %[[B:.*]] = memref.alloc() : memref<8xi1> +// CHECK: %[[CStatic:.*]] = memref.alloc() : memref<8xindex> +// CHECK: %[[C:.*]] = memref.cast %[[CStatic]] : memref<8xindex> to memref +// CHECK-DAG: linalg.fill ins(%{{.*}} : f64) outs(%[[A]] : memref<8xf64>) +// CHECK-DAG: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref<8xi1>) +// CHECK: return %[[C]] : memref +func.func @sparse_expansion() -> memref { + %0 = bufferization.alloc_tensor() : tensor<8x8xf64, #CSR> + %values, %filled, %added, %count = sparse_tensor.expand %0 + : tensor<8x8xf64, #CSR> to memref, memref, memref, index + return %added : memref +} +