diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -624,79 +624,4 @@ let hasVerifier = 1; } -//===----------------------------------------------------------------------===// -// Sparse Tensor Storage Operation. These operations are used internally by -// sparse tensor codegen to progressively lower sparse tensors. -//===----------------------------------------------------------------------===// - -def SparseTensor_StorageOp : SparseTensor_Op<"storage", []>, - Arguments<(ins Variadic:$inputs)>, - Results<(outs AnyTuple:$result)> { - let summary = "Pack a list of value into one sparse tensor storage value"; - let description = [{ - Pack a list of value into one sparse tensor storage value (represented as - a tuple) at the given index. - - The result tuple elements' type should match the corresponding type in the - input array. - - Example: - - ```mlir - %0 = sparse_tensor.storage(%1, %2): memref, memref - to tuple, memref> - ``` - }]; - - let assemblyFormat = " attr-dict `(` $inputs `)``:` type($inputs) `to` type($result)"; - let hasVerifier = 1; -} - -def SparseTensor_StorageGetOp : SparseTensor_Op<"storage_get", []>, - Arguments<(ins AnyTuple:$storage, - IndexAttr:$idx)>, - Results<(outs AnyType:$result)> { - let summary = "Get the data stored in the sparse tensor storage at the given index"; - let description = [{ - Get the data stored in the sparse tensor storage (represented as a tuple) - at the given index. - - The result type should match the corresponding element type in the tuple. - - Example: - - ```mlir - %0 = sparse_tensor.storage_get %arg0[0] : tuple, memref, f64> to memref - ``` - }]; - - let assemblyFormat = " $storage attr-dict `[`$idx`]` `:` type($storage) `to` type($result)"; - let hasVerifier = 1; -} - -def SparseTensor_StorageSetOp : SparseTensor_Op<"storage_set", []>, - Arguments<(ins AnyTuple:$storage, - AnyType:$value, - IndexAttr:$idx)>, - Results<(outs AnyTuple:$result)> { - let summary = "Set the data stored in the sparse tensor storage at given index"; - let description = [{ - Set the data stored in the sparse tensor storage (represented as a tuple) - at the given index. Return a new SSA value with the corresponding element - updated (others remain unchanged). - - The result type should match the original tuple type with only the updated - element type changed accordingly. - - Example: - - ```mlir - %0 = sparse_tensor.storage_set %arg0, %arg1 at 0 : tuple, memref, f64>, memref to tuple, memref, f64> - ``` - }]; - - let assemblyFormat = " $storage attr-dict `[`$idx`]``,` $value `:` type($storage) `,` type($value) `to` type($result)"; - let hasVerifier = 1; -} - #endif // SPARSETENSOR_OPS diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h @@ -155,22 +155,6 @@ std::unique_ptr createSparseTensorCodegenPass(); -//===----------------------------------------------------------------------===// -// The SparseTensorStorageExpansion pass. -//===----------------------------------------------------------------------===// - -/// Sparse tensor storage type converter from compound to expanded form. 
-class SparseTensorStorageTupleExpander : public TypeConverter { -public: - SparseTensorStorageTupleExpander(); -}; - -/// Sets up sparse tensor storage expansion rules. -void populateSparseTensorStorageExpansionPatterns(TypeConverter &typeConverter, - RewritePatternSet &patterns); - -std::unique_ptr createSparseTensorStorageExpansionPass(); - //===----------------------------------------------------------------------===// // Other rewriting rules and passes. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td @@ -175,39 +175,4 @@ ]; } -def SparseTensorStorageExpansion : Pass<"sparse-tensor-storage-expansion", "ModuleOp"> { - let summary = "Expand compounded sparse tensor storage into individual SSA values"; - let description = [{ - A pass that expands sparse tensor storage (aggregated by tuple) into - individual SSA values. It also lowers sparse tensor storage operations, - e.g., sparse_tensor.storage_get and sparse_tensor.storage_set. - - Example of the conversion: - - ```mlir - Before: - func.func @sparse_storage_set(%arg0: tuple, - memref, - f64>) - -> tuple, - memref, - f64> { - return %arg0 : tuple, memref, f64> - } - After: - func.func @sparse_storage_set(%arg0: memref, - %arg1: memref, - %arg2: f64) - -> (memref, memref, f64) { - return %arg0, %arg1, %arg2 : memref, memref, f64 - } - ``` - }]; - let constructor = "mlir::createSparseTensorStorageExpansionPass()"; - let dependentDialects = [ - "sparse_tensor::SparseTensorDialect", - ]; -} - - #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -482,65 +482,6 @@ "expected parent op to be sparse_tensor unary, binary, or reduce"); } -//===----------------------------------------------------------------------===// -// Sparse Tensor Storage Operation. 
-//===----------------------------------------------------------------------===// - -LogicalResult StorageOp::verify() { - auto retTypes = getResult().getType().getTypes(); - if (retTypes.size() != getInputs().size()) - return emitError("The number of inputs is inconsistent with output tuple"); - - for (auto pair : llvm::zip(getInputs(), retTypes)) { - auto input = std::get<0>(pair); - auto retTy = std::get<1>(pair); - - if (input.getType() != retTy) - return emitError(llvm::formatv("Type mismatch between input (type={0}) " - "and output tuple element (type={1})", - input.getType(), retTy)); - } - return success(); -} - -LogicalResult StorageGetOp::verify() { - uint64_t extractIdx = getIdx().getZExtValue(); - auto innerTypeArray = getStorage().getType().getTypes(); - if (extractIdx >= innerTypeArray.size()) - return emitError(llvm::formatv( - "Out-of-bound access with index={0} on tuple with length={1}", - extractIdx, innerTypeArray.size())); - - auto expectedTy = getStorage().getType().getType(extractIdx); - auto returnTy = getResult().getType(); - if (expectedTy != returnTy) - return emitError(llvm::formatv( - "Type mismatch between the returning type (type={0}) and the " - "corresponding element type at index {1} (type={2})", - expectedTy, extractIdx, returnTy)); - return success(); -} - -LogicalResult StorageSetOp::verify() { - uint64_t setIdx = getIdx().getZExtValue(); - SmallVector expectedElemTy(getStorage().getType().getTypes()); - if (setIdx >= expectedElemTy.size()) - return emitError(llvm::formatv( - "Out-of-bound access with index = {0} on tuple with length={1}", setIdx, - expectedElemTy.size())); - - // Updates the element type after storage_set. - expectedElemTy[setIdx] = getValue().getType(); - auto expectedTy = TupleType::get(getContext(), expectedElemTy); - auto returnTy = getResult().getType(); - if (expectedTy != returnTy) - return emitError( - llvm::formatv("Type mismatch between the returning type " - "(type={0}) and the expected type (type={1})", - returnTy, expectedTy)); - return success(); -} - //===----------------------------------------------------------------------===// // TensorDialect Methods. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt @@ -7,7 +7,6 @@ SparseTensorConversion.cpp SparseTensorPasses.cpp SparseTensorRewriting.cpp - SparseTensorStorageExpansion.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -54,8 +54,30 @@ return i; } +/// Flatten a list of operands that may contain sparse tensors. +static void flattenOperands(ValueRange operands, + SmallVectorImpl &flattened) { + // In case of + // sparse_tensor, c, sparse_tensor + // ==> + // memref ..., c, memref ... 
+  for (auto operand : operands) {
+    if (auto cast =
+            dyn_cast<UnrealizedConversionCastOp>(operand.getDefiningOp());
+        cast && getSparseTensorEncoding(cast->getResultTypes()[0]))
+      // An unrealized_conversion_cast will be inserted by the type converter
+      // to bridge the gap of the 1:N conversion between sparse tensors and
+      // their fields. In this case, take the operands of the cast and replace
+      // the sparse tensor output with the flattened values.
+      flattened.append(cast.getOperands().begin(), cast.getOperands().end());
+    else
+      flattened.push_back(operand);
+  }
+}
+
 /// Maps a sparse tensor type to the appropriate compounded buffers.
-static Optional<Type> convertSparseTensorType(Type type) {
+static Optional<LogicalResult>
+convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
   auto enc = getSparseTensorEncoding(type);
   if (!enc)
     return llvm::None;
@@ -86,7 +108,6 @@
   // };
   //
   unsigned rank = rType.getShape().size();
-  SmallVector<Type> fields;
   // The dimSizes array.
   fields.push_back(MemRefType::get({rank}, indexType));
   // Per-dimension storage.
@@ -115,10 +136,7 @@
   }
   // The values array.
   fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, eltType));
-  // Sparse tensor storage (temporarily) lives in a tuple. This allows a
-  // simple 1:1 type conversion during codegen. A subsequent pass uses
-  // a 1:N type conversion to expand the tuple into its fields.
-  return TupleType::get(context, fields);
+  return success();
 }

 // Returns field index of sparse tensor type for pointers/indices, when set.
@@ -158,25 +176,6 @@
   return -1;
 }

-/// Returns field type in tuple at given index.
-static Type getFieldType(Value tuple, unsigned field) {
-  return tuple.getType().cast<TupleType>().getType(field);
-}
-
-/// Creates tuple get operation at given index.
-static Value createTupleGet(OpBuilder &builder, Location loc, Value tuple,
-                            unsigned field) {
-  Type indexType = builder.getIndexType();
-  return builder.create<StorageGetOp>(loc, getFieldType(tuple, field), tuple,
-                                      builder.getIntegerAttr(indexType, field));
-}
-
-/// Creates tuple.
-static Value createTupleMake(OpBuilder &builder, Location loc, Type type,
-                             ValueRange values) {
-  return builder.create<StorageOp>(loc, type, values);
-}
-
 /// Create allocation operation.
 static Value createAllocation(OpBuilder &builder, Location loc, Type type,
                               Value sz) {
@@ -184,14 +183,15 @@
   return builder.create<memref::AllocOp>(loc, memType, sz);
 }

-/// Creates allocation tuple for sparse tensor type.
+/// Creates allocation for each field in sparse tensor type.
 ///
 /// TODO: for efficiency, we will need heuristics to make educated guesses
 ///       on the required final sizes; also, we will need an improved
 ///       memory allocation scheme with capacity and reallocation
 ///
-static Value createAllocTuple(OpBuilder &builder, Location loc, Type type,
-                              ValueRange dynSizes) {
+static void createAllocFields(OpBuilder &builder, Location loc, Type type,
+                              ValueRange dynSizes,
+                              SmallVectorImpl<Value> &fields) {
   auto enc = getSparseTensorEncoding(type);
   assert(enc);
   // Construct the basic types.
@@ -202,10 +202,8 @@
   Type idxType = idxWidth ? builder.getIntegerType(idxWidth) : indexType;
   Type ptrType = ptrWidth ? builder.getIntegerType(ptrWidth) : indexType;
   Type eltType = rType.getElementType();
-  // Build the allocation tuple, using heuristics for pre-allocation.
   auto shape = rType.getShape();
   unsigned rank = shape.size();
-  SmallVector<Value> fields;
   bool allDense = true;
   Value one = constantIndex(builder, loc, 1);
   Value linear = one;
@@ -254,9 +252,6 @@
   // In all other cases, we resort to the heuristic initial value.
   Value valuesSz = allDense ?
                             linear : heuristic;
   fields.push_back(createAllocation(builder, loc, eltType, valuesSz));
-  // Construct tuple allocation.
-  Type tupleType = *convertSparseTensorType(type);
-  return createTupleMake(builder, loc, tupleType, fields);
 }
 /// Returns integral constant, if defined.
@@ -270,14 +265,80 @@
 // Codegen rules.
 //===----------------------------------------------------------------------===//

-/// Sparse codegen rule for returns.
+/// Sparse tensor storage conversion rule for returns.
 class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
   matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
+    SmallVector<Value> flattened;
+    flattenOperands(adaptor.getOperands(), flattened);
+    // Create a return with the flattened values extracted from the sparse
+    // tensors.
+    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, flattened);
+    return success();
+  }
+};
+
+/// Sparse tensor storage conversion rule for calls.
+class SparseCallConverter : public OpConversionPattern<func::CallOp> {
 public:
+  // The default CallOp converter cannot handle 1:N type conversion.
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
+  matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // In case of:
+    //   sparse_tensor, f, sparse_tensor = call @foo(...)
+    // ==>
+    //   memref..., f, memref... = call @foo(...), followed by
+    //   cast(memref...)->sparse_tensor, f, cast(memref...)->sparse_tensor
+    SmallVector<Type> finalRetTy;
+    if (failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy)))
+      return failure();
+
+    // (1) Generates a new call with the flattened return values.
+    SmallVector<Value> flattened;
+    flattenOperands(adaptor.getOperands(), flattened);
+    auto newCall = rewriter.create<func::CallOp>(loc, op.getCallee(),
+                                                 finalRetTy, flattened);
+    // (2) Create cast operations for the sparse tensor returns.
+    SmallVector<Value> castedRet;
+    // Tracks the offset of the current return value (of the original call)
+    // relative to the new call (after sparse tensor flattening).
+    unsigned retOffset = 0;
+    // Temporary buffer to hold the flattened list of types for
+    // a sparse tensor.
+    SmallVector<Type> sparseFlat;
+    for (auto ret : op.getResults()) {
+      assert(retOffset < newCall.getNumResults());
+      auto retType = ret.getType();
+      if (failed(typeConverter->convertType(retType, sparseFlat)))
+        // This should never happen.
+        llvm_unreachable("Failed to convert type in sparse tensor codegen");
+
+      // Converted types cannot be empty when the type conversion succeeds.
+      assert(!sparseFlat.empty());
+      if (sparseFlat.size() > 1) {
+        auto flatSize = sparseFlat.size();
+        ValueRange sparseElem(iterator_range<ResultRange::iterator>(
+            newCall.result_begin() + retOffset,
+            newCall.result_begin() + retOffset + flatSize));
+        auto castOp = rewriter.create<UnrealizedConversionCastOp>(
+            loc, TypeRange({retType}), sparseElem);
+        castedRet.push_back(castOp.getResult(0));
+        retOffset += flatSize;
+      } else {
+        // If this is a 1:1 conversion, there is no need for casting.
+        castedRet.push_back(newCall.getResult(retOffset));
+        retOffset++;
+      }
+      sparseFlat.clear();
+    }
+
+    assert(castedRet.size() == op.getNumResults());
+    rewriter.replaceOp(op, castedRet);
     return success();
   }
 };
@@ -306,10 +367,11 @@
     }
     // Any other query can consult the dimSizes array at field 0,
    // accounting for the reordering applied to the sparse storage.
-    Value tuple = adaptor.getSource();
-    Value dimSizes = createTupleGet(rewriter, loc, tuple, 0);
+    auto tuple = llvm::cast<UnrealizedConversionCastOp>(
+        adaptor.getSource().getDefiningOp());
     rewriter.replaceOpWithNewOp<memref::LoadOp>(
-        op, dimSizes, constantIndex(rewriter, loc, toStored(enc, *index)));
+        op, tuple.getInputs().front(),
+        constantIndex(rewriter, loc, toStored(enc, *index)));
     return success();
   }
 };
@@ -345,10 +407,13 @@
       return failure();
     if (op.getCopy())
       return rewriter.notifyMatchFailure(op, "tensor copy not implemented");
-    // Construct allocation tuple.
-    Value tuple = createAllocTuple(rewriter, op->getLoc(), resType,
-                                   adaptor.getOperands());
-    rewriter.replaceOp(op, tuple);
+
+    // Construct an allocation for each field.
+    Location loc = op.getLoc();
+    SmallVector<Value> fields;
+    createAllocFields(rewriter, loc, resType, adaptor.getOperands(), fields);
+    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
+        op, TypeRange{resType}, fields);
     return success();
   }
 };
@@ -364,86 +429,101 @@
     auto enc = getSparseTensorEncoding(op.getTensor().getType());
     if (!enc)
       return failure();
-    // Replace the tuple deallocation with field deallocations.
-    Location loc = op->getLoc();
-    Value tuple = adaptor.getTensor();
-    for (unsigned i = 0, sz = tuple.getType().cast<TupleType>().size(); i < sz;
-         i++) {
-      Value mem = createTupleGet(rewriter, loc, tuple, i);
-      rewriter.create<memref::DeallocOp>(loc, mem);
-    }
+
+    // Replace the sparse tensor deallocation with field deallocations.
+    Location loc = op.getLoc();
+    auto tuple = llvm::cast<UnrealizedConversionCastOp>(
+        adaptor.getTensor().getDefiningOp());
+    for (auto input : tuple.getInputs())
+      // Deallocate every buffer used to store the sparse tensor.
+      rewriter.create<memref::DeallocOp>(loc, input);
+
     rewriter.eraseOp(op);
     return success();
   }
 };

-/// Sparse codegen rule for pointer accesses.
-class SparseToPointersConverter : public OpConversionPattern<ToPointersOp> {
+/// Sparse codegen rule for tensor rematerialization.
+class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
+  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Optional<int64_t> index = getConstantInt(adaptor.getOperands()[1]);
-    if (!index)
-      return failure();
-    // Replace the requested pointer access with the corresponding field.
-    Location loc = op->getLoc();
-    Value tuple = adaptor.getTensor();
-    unsigned i = getFieldIndex(op.getTensor().getType(), /*ptrDim=*/*index, -1);
-    rewriter.replaceOp(op, createTupleGet(rewriter, loc, tuple, i));
+    if (op.getHasInserts()) {
+      // Finalize any pending insertions.
+      // TODO: implement
+    }
+    rewriter.replaceOp(op, adaptor.getOperands());
     return success();
   }
 };

-/// Sparse codegen rule for index accesses.
-class SparseToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
+/// Base class for getter-like operations, e.g., to_indices, to_pointers.
+template <typename Base, typename SourceOp>
+class SparseGetterOpConverter : public OpConversionPattern<SourceOp> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using OpAdaptor = typename SourceOp::Adaptor;
+  using OpConversionPattern<SourceOp>::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
+  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Optional<int64_t> index = getConstantInt(adaptor.getOperands()[1]);
-    if (!index)
+    // Replace the requested access with the corresponding field.
+    // The cast_op is inserted by the type converter to intermix the 1:N type
+    // conversion.
+    auto tuple = llvm::cast<UnrealizedConversionCastOp>(
+        adaptor.getTensor().getDefiningOp());
+    auto idx = Base::getIndexForOp(tuple, op);
+    if (!idx)
+      // Failed to get the index.
       return failure();
-    // Replace the requested indices access with the corresponding field.
-    Location loc = op->getLoc();
-    Value tuple = adaptor.getTensor();
-    unsigned i = getFieldIndex(op.getTensor().getType(), -1, /*idxDim=*/*index);
-    rewriter.replaceOp(op, createTupleGet(rewriter, loc, tuple, i));
+    auto fields = tuple.getInputs();
+    assert(*idx < fields.size());
+    rewriter.replaceOp(op, fields[*idx]);
     return success();
   }
 };

-/// Sparse codegen rule for value accesses.
-class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
+/// Sparse codegen rule for pointer accesses.
+class SparseToPointersConverter
+    : public SparseGetterOpConverter<SparseToPointersConverter, ToPointersOp> {
 public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    // Replace the requested values access with the corresponding field.
-    Location loc = op->getLoc();
-    Value tuple = adaptor.getTensor();
-    unsigned i = tuple.getType().cast<TupleType>().size() - 1; // last
-    rewriter.replaceOp(op, createTupleGet(rewriter, loc, tuple, i));
-    return success();
+  using SparseGetterOpConverter::SparseGetterOpConverter;
+  // Callback for SparseGetterOpConverter.
+  static Optional<unsigned> getIndexForOp(UnrealizedConversionCastOp /*tuple*/,
+                                          ToPointersOp op) {
+    Optional<int64_t> dim = getConstantInt(op.getDim());
+    if (!dim)
+      return llvm::None; // variable dim
+    return getFieldIndex(op.getTensor().getType(), /*ptrDim=*/*dim, -1);
   }
 };

-/// Sparse codegen rule for tensor rematerialization.
-class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
+/// Sparse codegen rule for index accesses.
+class SparseToIndicesConverter
+    : public SparseGetterOpConverter<SparseToIndicesConverter, ToIndicesOp> {
 public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    if (op.getHasInserts()) {
-      // Finalize any pending insertions.
-      // TODO: implement
-    }
-    rewriter.replaceOp(op, adaptor.getOperands());
-    return success();
+  using SparseGetterOpConverter::SparseGetterOpConverter;
+  // Callback for SparseGetterOpConverter.
+  static Optional<unsigned> getIndexForOp(UnrealizedConversionCastOp /*tuple*/,
+                                          ToIndicesOp op) {
+    Optional<int64_t> dim = getConstantInt(op.getDim());
+    if (!dim)
+      return llvm::None; // variable dim
+    return getFieldIndex(op.getTensor().getType(), -1, /*idxDim=*/*dim);
+  }
+};
+
+/// Sparse codegen rule for value accesses.
+class SparseToValuesConverter
+    : public SparseGetterOpConverter<SparseToValuesConverter, ToValuesOp> {
+public:
+  using SparseGetterOpConverter::SparseGetterOpConverter;
+  // Callback for SparseGetterOpConverter.
+  static Optional<unsigned> getIndexForOp(UnrealizedConversionCastOp tuple,
+                                          ToValuesOp /*op*/) {
+    // The last field holds the value buffer.
+    return tuple.getInputs().size() - 1;
   }
 };
@@ -466,9 +546,9 @@
 /// the sparsification of linear algebra operations.
 void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
                                                RewritePatternSet &patterns) {
-  patterns.add<SparseReturnConverter, SparseDimOpConverter, SparseCastConverter,
-               SparseTensorAllocConverter, SparseTensorDeallocConverter,
-               SparseToPointersConverter, SparseToIndicesConverter,
-               SparseToValuesConverter, SparseTensorLoadConverter>(
-      typeConverter, patterns.getContext());
+  patterns.add<SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
+               SparseCastConverter, SparseTensorAllocConverter,
+               SparseTensorDeallocConverter, SparseToPointersConverter,
+               SparseToIndicesConverter, SparseToValuesConverter,
+               SparseTensorLoadConverter>(typeConverter, patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -24,7 +24,6 @@
 #define GEN_PASS_DEF_SPARSIFICATIONPASS
 #define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
 #define GEN_PASS_DEF_SPARSETENSORCODEGEN
-#define GEN_PASS_DEF_SPARSETENSORSTORAGEEXPANSION
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
 } // namespace mlir
@@ -154,9 +153,8 @@
     RewritePatternSet patterns(ctx);
     SparseTensorTypeToBufferConverter converter;
     ConversionTarget target(*ctx);
-    // Almost everything in the sparse dialect must go!
+    // Everything in the sparse dialect must go!
     target.addIllegalDialect<SparseTensorDialect>();
-    target.addLegalOp<StorageOp, StorageGetOp, StorageSetOp>();
     // All dynamic rules below accept new function, call, return, and various
     // tensor and bufferization operations as legal output of the rewriting
     // provided that all sparse tensor types have been fully rewritten.
@@ -181,53 +179,13 @@
     target.addLegalDialect();
-    // Populate with rules and apply rewriting rules.
-    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
-                                                                   converter);
-    populateCallOpTypeConversionPattern(patterns, converter);
-    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
-                                                         target);
-    populateSparseTensorCodegenPatterns(converter, patterns);
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
-      signalPassFailure();
-  }
-};
-
-struct SparseTensorStorageExpansionPass
-    : public impl::SparseTensorStorageExpansionBase<
-          SparseTensorStorageExpansionPass> {
-
-  SparseTensorStorageExpansionPass() = default;
-  SparseTensorStorageExpansionPass(
-      const SparseTensorStorageExpansionPass &pass) = default;
-
-  void runOnOperation() override {
-    auto *ctx = &getContext();
-    RewritePatternSet patterns(ctx);
-    SparseTensorStorageTupleExpander converter;
-    ConversionTarget target(*ctx);
-    // Now, everything in the sparse dialect must go!
-    target.addIllegalDialect<SparseTensorDialect>();
-    // All dynamic rules below accept new function, call, return.
-    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
-      return converter.isSignatureLegal(op.getFunctionType());
-    });
-    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
-      return converter.isSignatureLegal(op.getCalleeType());
-    });
-    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
-      return converter.isLegal(op.getOperandTypes());
-    });
-    // We generate UnrealizedConversionCastOp to intermix tuples and a
-    // list of types.
     target.addLegalOp<UnrealizedConversionCastOp>();
     // Populate with rules and apply rewriting rules.
populateFunctionOpInterfaceTypeConversionPattern(patterns, converter); scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns, target); - populateSparseTensorStorageExpansionPatterns(converter, patterns); + populateSparseTensorCodegenPatterns(converter, patterns); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); @@ -277,7 +235,3 @@ std::unique_ptr mlir::createSparseTensorCodegenPass() { return std::make_unique(); } - -std::unique_ptr mlir::createSparseTensorStorageExpansionPass() { - return std::make_unique(); -} diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp deleted file mode 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp +++ /dev/null @@ -1,218 +0,0 @@ -//===- SparseTensorStorageExpansion.cpp - Sparse tensor storage expansion ===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// The sparse tensor storage expansion pass expands the compound storage for -// sparse tensors (using tuple) to flattened SSA values. -// -//===----------------------------------------------------------------------===// - -#include "CodegenUtils.h" - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" -#include "mlir/Dialect/SparseTensor/Transforms/Passes.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Transforms/DialectConversion.h" - -using namespace mlir; -using namespace mlir::sparse_tensor; - -namespace { - -//===----------------------------------------------------------------------===// -// Helper methods. -//===----------------------------------------------------------------------===// - -/// Expands sparse tensor storage tuple. -static Optional -convertSparseTensorStorageTuple(Type t, SmallVectorImpl &result) { - if (auto tuple = t.dyn_cast()) { - // Note that it does not handle nest tuples, but it is fine - // for sparse compiler as they will not be generated. - result.append(tuple.getTypes().begin(), tuple.getTypes().end()); - return success(); - } - return llvm::None; -} - -/// Flatten a list of operands that may contain tuples. -static void flattenOperands(ValueRange operands, - SmallVectorImpl &flattened) { - // In case of - // tuple, c, tuple - // ==> - // a, b, c, d, e - for (auto operand : operands) { - if (auto cast = - dyn_cast(operand.getDefiningOp()); - cast && cast->getResultTypes()[0].isa()) - // An unrealized_conversion_cast will be inserted by type converter to - // inter-mix the gap between 1:N conversion between tuple and types. - // In this case, take the operands in the cast and replace the tuple - // output with the flattened type array. - flattened.append(cast.getOperands().begin(), cast.getOperands().end()); - else - flattened.push_back(operand); - } -} -//===----------------------------------------------------------------------===// -// Conversion rules. -//===----------------------------------------------------------------------===// - -/// Sparse tensor storage conversion rule for sparse_tensor::storage. 
-class SparseStorageConversion : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(StorageOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Simply convert it to a unrealize_conversion_cast. - // We should guarantee that all uses of sparse_tensor.storage op will - // be eventually eliminated by accessing the flattened SSA values directly. - rewriter.replaceOpWithNewOp( - op, TypeRange{op.getType()}, adaptor.getInputs()); - return success(); - } -}; - -/// Sparse tensor storage conversion rule for sparse_tensor::storage_get. -class SparseStorageGetConverter : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(StorageGetOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto castOp = - cast(adaptor.getStorage().getDefiningOp()); - uint64_t idx = op.getIdx().getZExtValue(); - assert(idx < castOp.getOperands().size()); - - rewriter.replaceOp(op, castOp.getOperand(idx)); - return success(); - } -}; - -/// Sparse tensor storage conversion rule for sparse_tensor::storage_set. -class SparseStorageSetConverter : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(StorageSetOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto castOp = - cast(adaptor.getStorage().getDefiningOp()); - uint64_t idx = op.getIdx().getZExtValue(); - - SmallVector values(castOp.getOperands()); - assert(idx < values.size()); - - // Updates the corresponding element. - values[idx] = adaptor.getValue(); - rewriter.replaceOpWithNewOp( - op, TypeRange{op.getType()}, values); - return success(); - } -}; - -/// Sparse tensor storage conversion rule for returns. -class SparseStorageReturnConverter - : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - SmallVector flattened; - flattenOperands(adaptor.getOperands(), flattened); - // Create a return with the flattened value extracted from tuple. - rewriter.replaceOpWithNewOp(op, flattened); - return success(); - } -}; - -/// Sparse tensor storage conversion rule for calls. -class SparseStorageCallConverter : public OpConversionPattern { -public: - // The default CallOp converter can not handle 1:N type conversion properly - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(func::CallOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - // In case of: - // tuple(a, b), f, tuple(c, d) = call @foo(...) - // ==> - // a, b, f, c, d = call @foo(...) - // cast(a, b)->tuple, f, cast(c,d)->tuple - SmallVector finalRetTy; - if (failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy))) - return failure(); - - // (1) Genereates new call with flattened return value. - SmallVector flattened; - flattenOperands(adaptor.getOperands(), flattened); - auto newCall = rewriter.create(loc, op.getCallee(), - finalRetTy, flattened); - - // (2) Create cast operation for tuple returns. 
- SmallVector castedRet; - // Tracks the offset of current return value (of the orignal call) - // relative to the new call (after tuple flattening); - unsigned retOffset = 0; - for (auto ret : op.getResults()) { - assert(retOffset < newCall.getNumResults()); - auto tupleRet = ret.getType().dyn_cast(); - if (tupleRet) { - auto tupleSize = tupleRet.size(); - // NOTE: The range is computed under the assumption of non-recursive - // tuple type. - ValueRange tupleElem(iterator_range( - newCall.result_begin() + retOffset, - newCall.result_begin() + retOffset + tupleSize)); - auto castOp = rewriter.create( - loc, TypeRange({tupleRet}), tupleElem); - castedRet.push_back(castOp.getResult(0)); - retOffset += tupleSize; - } else { - // If this not a tuple, simply add it into returned values. - castedRet.push_back(ret); - retOffset++; - } - } - - assert(castedRet.size() == op.getNumResults()); - rewriter.replaceOp(op, castedRet); - return success(); - } -}; - -} // namespace - -//===----------------------------------------------------------------------===// -// Sparse tensor storage expansion -//===----------------------------------------------------------------------===// - -mlir::SparseTensorStorageTupleExpander::SparseTensorStorageTupleExpander() { - addConversion([](Type type) { return type; }); - addConversion(convertSparseTensorStorageTuple); -} - -//===----------------------------------------------------------------------===// -// Public method for populating conversion rules. -//===----------------------------------------------------------------------===// - -/// Populates the given patterns list with conversion rules required -/// to expand compounded sparse tensor tuples. -void mlir::populateSparseTensorStorageExpansionPatterns( - TypeConverter &typeConverter, RewritePatternSet &patterns) { - patterns.add(typeConverter, - patterns.getContext()); -} diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir --- a/mlir/test/Dialect/SparseTensor/codegen.mlir +++ b/mlir/test/Dialect/SparseTensor/codegen.mlir @@ -1,5 +1,4 @@ -// RUN: mlir-opt %s --sparse-tensor-codegen --canonicalize --cse | FileCheck %s --check-prefixes=CHECK,CHECK-CODEGEN -// RUN: mlir-opt %s --sparse-tensor-codegen --sparse-tensor-storage-expansion --canonicalize --cse | FileCheck %s --check-prefixes=CHECK,CHECK-STORAGE +// RUN: mlir-opt %s --sparse-tensor-codegen --canonicalize --cse | FileCheck %s #SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], @@ -41,96 +40,114 @@ dimOrdering = affine_map<(i, j, k) -> (k, i, j)> }> -// CHECK-CODEGEN-LABEL: func @sparse_nop( -// CHECK-CODEGEN-SAME: %[[A:.*]]: tuple, memref, memref, memref>) -// CHECK-CODEGEN: return %[[A]] : tuple, memref, memref, memref> -// -// CHECK-STORAGE-LABEL: func @sparse_nop( -// CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-STORAGE-SAME: %[[A1:.*1]]: memref, -// CHECK-STORAGE-SAME: %[[A2:.*2]]: memref, -// CHECK-STORAGE-SAME: %[[A3:.*3]]: memref) -// CHECK-STORAGE: return %[[A0]], %[[A1]], %[[A2]], %[[A3]] : memref<1xindex>, memref, memref, memref +// CHECK-LABEL: func @sparse_nop( +// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: memref) +// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]] : memref<1xindex>, memref, memref, memref func.func @sparse_nop(%arg0: tensor) -> tensor { return %arg0 : tensor } -// CHECK-CODEGEN-LABEL: func @sparse_nop_cast( -// CHECK-CODEGEN-SAME: 
%[[A:.*]]: tuple, memref, memref, memref>) -// CHECK-CODEGEN: return %[[A]] : tuple, memref, memref, memref> +// CHECK-LABEL: func @sparse_nop_multi_ret( +// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: memref<1xindex>, +// CHECK-SAME: %[[A5:.*5]]: memref, +// CHECK-SAME: %[[A6:.*6]]: memref, +// CHECK-SAME: %[[A7:.*7]]: memref) -> +// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]] +func.func @sparse_nop_multi_ret(%arg0: tensor, + %arg1: tensor) -> + (tensor, tensor) { + return %arg0, %arg1 : tensor, tensor +} + +// CHECK-LABEL: func @sparse_nop_call( +// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: memref<1xindex>, +// CHECK-SAME: %[[A5:.*5]]: memref, +// CHECK-SAME: %[[A6:.*6]]: memref, +// CHECK-SAME: %[[A7:.*7]]: memref) +// CHECK: %[[T0:.*]]:8 = call @sparse_nop_multi_ret(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]]) +// CHECK: return %[[T0]]#0, %[[T0]]#1, %[[T0]]#2, %[[T0]]#3, %[[T0]]#4, %[[T0]]#5, %[[T0]]#6, %[[T0]]#7 +func.func @sparse_nop_call(%arg0: tensor, + %arg1: tensor) -> + (tensor, tensor) { + %1, %2 = call @sparse_nop_multi_ret(%arg0, %arg1) : + (tensor, tensor) -> + (tensor, tensor) + return %1, %2: tensor, tensor +} + // -// CHECK-STORAGE-LABEL: func @sparse_nop_cast( -// CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-STORAGE-SAME: %[[A1:.*1]]: memref, -// CHECK-STORAGE-SAME: %[[A2:.*2]]: memref, -// CHECK-STORAGE-SAME: %[[A3:.*3]]: memref) -// CHECK-STORAGE: return %[[A0]], %[[A1]], %[[A2]], %[[A3]] : memref<1xindex>, memref, memref, memref +// CHECK-LABEL: func @sparse_nop_cast( +// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: memref) +// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]] : memref<1xindex>, memref, memref, memref func.func @sparse_nop_cast(%arg0: tensor<64xf32, #SparseVector>) -> tensor { %0 = tensor.cast %arg0 : tensor<64xf32, #SparseVector> to tensor return %0 : tensor } -// CHECK-CODEGEN-LABEL: func @sparse_nop_cast_3d( -// CHECK-CODEGEN-SAME: %[[A:.*]]: tuple, memref>) -// CHECK-CODEGEN: return %[[A]] : tuple, memref> // -// CHECK-STORAGE-LABEL: func @sparse_nop_cast_3d( -// CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<3xindex>, -// CHECK-STORAGE-SAME: %[[A1:.*1]]: memref) -// CHECK-STORAGE: return %[[A0]], %[[A1]] : memref<3xindex>, memref +// CHECK-LABEL: func @sparse_nop_cast_3d( +// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref) +// CHECK: return %[[A0]], %[[A1]] : memref<3xindex>, memref func.func @sparse_nop_cast_3d(%arg0: tensor<10x20x30xf32, #Dense3D>) -> tensor { %0 = tensor.cast %arg0 : tensor<10x20x30xf32, #Dense3D> to tensor return %0 : tensor } -// CHECK-CODEGEN-LABEL: func @sparse_dense_2d( -// CHECK-CODEGEN-SAME: %[[A:.*]]: tuple, memref>) // -// CHECK-STORAGE-LABEL: func @sparse_dense_2d( -// CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-STORAGE-SAME: %[[A1:.*1]]: memref) { -// CHECK-STORAGE: return +// CHECK-LABEL: func @sparse_dense_2d( +// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref) { +// CHECK: return func.func @sparse_dense_2d(%arg0: tensor) { return } -// CHECK-CODEGEN-LABEL: func @sparse_row( -// 
CHECK-CODEGEN-SAME: %[[A:.*]]: tuple, memref, memref, memref>) // -// CHECK-STORAGE-LABEL: func @sparse_row( -// CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-STORAGE-SAME: %[[A1:.*1]]: memref, -// CHECK-STORAGE-SAME: %[[A2:.*2]]: memref, -// CHECK-STORAGE-SAME: %[[A3:.*3]]: memref) { -// CHECK-STORAGE: return +// CHECK-LABEL: func @sparse_row( +// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: memref) { +// CHECK: return func.func @sparse_row(%arg0: tensor) { return } -// CHECK-CODEGEN-LABEL: func @sparse_csr( -// CHECK-CODEGEN-SAME: %[[A:.*]]: tuple, memref, memref, memref>) // -// CHECK-STORAGE-LABEL: func @sparse_csr( -// CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-STORAGE-SAME: %[[A1:.*1]]: memref, -// CHECK-STORAGE-SAME: %[[A2:.*2]]: memref, -// CHECK-STORAGE-SAME: %[[A3:.*3]]: memref) { -// CHECK-STORAGE: return +// CHECK-LABEL: func @sparse_csr( +// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: memref) { +// CHECK: return func.func @sparse_csr(%arg0: tensor) { return } -// CHECK-CODEGEN-LABEL: func @sparse_dcsr( -// CHECK-CODEGEN-SAME: %[[A:.*]]: tuple, memref, memref, memref, memref, memref>) // -// CHECK-STORAGE-LABEL: func @sparse_dcsr( -// CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-STORAGE-SAME: %[[A1:.*1]]: memref, -// CHECK-STORAGE-SAME: %[[A2:.*2]]: memref, -// CHECK-STORAGE-SAME: %[[A3:.*3]]: memref, -// CHECK-STORAGE-SAME: %[[A4:.*4]]: memref, -// CHECK-STORAGE-SAME: %[[A5:.*5]]: memref) { -// CHECK-STORAGE: return +// CHECK-LABEL: func @sparse_dcsr( +// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: memref) { +// CHECK: return func.func @sparse_dcsr(%arg0: tensor) { return } @@ -139,16 +156,12 @@ // Querying for dimension 1 in the tensor type can immediately // fold using the original static dimension sizes. // -// CHECK-CODEGEN-LABEL: func @sparse_dense_3d( -// CHECK-CODEGEN-SAME: %[[A:.*]]: tuple, memref>) -// CHECK-CODEGEN: %[[C:.*]] = arith.constant 20 : index -// CHECK-CODEGEN: return %[[C]] : index // -// CHECK-STORAGE-LABEL: func @sparse_dense_3d( -// CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<3xindex>, -// CHECK-STORAGE-SAME: %[[A1:.*1]]: memref) -// CHECK-STORAGE: %[[C:.*]] = arith.constant 20 : index -// CHECK-STORAGE: return %[[C]] : index +// CHECK-LABEL: func @sparse_dense_3d( +// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref) +// CHECK: %[[C:.*]] = arith.constant 20 : index +// CHECK: return %[[C]] : index func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index { %c = arith.constant 1 : index %0 = tensor.dim %arg0, %c : tensor<10x20x30xf64, #Dense3D> @@ -160,103 +173,74 @@ // into querying for dimension 2 in the stored sparse tensor scheme, // since the latter honors the dimOrdering. 
// -// CHECK-CODEGEN-LABEL: func @sparse_dense_3d_dyn( -// CHECK-CODEGEN-SAME: %[[A:.*]]: tuple, memref>) -// CHECK-CODEGEN: %[[C:.*]] = arith.constant 2 : index -// CHECK-CODEGEN: %[[F:.*]] = sparse_tensor.storage_get %[[A]][0] : tuple, memref> to memref<3xindex> -// CHECK-CODEGEN: %[[L:.*]] = memref.load %[[F]][%[[C]]] : memref<3xindex> -// CHECK-CODEGEN: return %[[L]] : index // -// CHECK-STORAGE-LABEL: func @sparse_dense_3d_dyn( -// CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<3xindex>, -// CHECK-STORAGE-SAME: %[[A1:.*1]]: memref) -// CHECK-STORAGE: %[[C:.*]] = arith.constant 2 : index -// CHECK-STORAGE: %[[L:.*]] = memref.load %[[A0]][%[[C]]] : memref<3xindex> -// CHECK-STORAGE: return %[[L]] : index +// CHECK-LABEL: func @sparse_dense_3d_dyn( +// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref) +// CHECK: %[[C:.*]] = arith.constant 2 : index +// CHECK: %[[L:.*]] = memref.load %[[A0]][%[[C]]] : memref<3xindex> +// CHECK: return %[[L]] : index func.func @sparse_dense_3d_dyn(%arg0: tensor) -> index { %c = arith.constant 1 : index %0 = tensor.dim %arg0, %c : tensor return %0 : index } -// CHECK-CODEGEN-LABEL: func @sparse_pointers_dcsr( -// CHECK-CODEGEN-SAME: %[[A:.*]]: tuple, memref, memref, memref, memref, memref>) -// CHECK-CODEGEN: %[[F:.*]] = sparse_tensor.storage_get %[[A]][3] : tuple, memref, memref, memref, memref, memref> to memref -// CHECK-CODEGEN: return %[[F]] : memref // -// CHECK-STORAGE-LABEL: func @sparse_pointers_dcsr( -// CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-STORAGE-SAME: %[[A1:.*1]]: memref, -// CHECK-STORAGE-SAME: %[[A2:.*2]]: memref, -// CHECK-STORAGE-SAME: %[[A3:.*3]]: memref, -// CHECK-STORAGE-SAME: %[[A4:.*4]]: memref, -// CHECK-STORAGE-SAME: %[[A5:.*5]]: memref) -// CHECK-STORAGE: return %[[A3]] : memref +// CHECK-LABEL: func @sparse_pointers_dcsr( +// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: memref) +// CHECK: return %[[A3]] : memref func.func @sparse_pointers_dcsr(%arg0: tensor) -> memref { %c = arith.constant 1 : index %0 = sparse_tensor.pointers %arg0, %c : tensor to memref return %0 : memref } -// CHECK-CODEGEN-LABEL: func @sparse_indices_dcsr( -// CHECK-CODEGEN-SAME: %[[A:.*]]: tuple, memref, memref, memref, memref, memref>) -// CHECK-CODEGEN: %[[F:.*]] = sparse_tensor.storage_get %[[A]][4] : tuple, memref, memref, memref, memref, memref> to memref -// CHECK-CODEGEN: return %[[F]] : memref // -// CHECK-STORAGE-LABEL: func @sparse_indices_dcsr( -// CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-STORAGE-SAME: %[[A1:.*1]]: memref, -// CHECK-STORAGE-SAME: %[[A2:.*2]]: memref, -// CHECK-STORAGE-SAME: %[[A3:.*3]]: memref, -// CHECK-STORAGE-SAME: %[[A4:.*4]]: memref, -// CHECK-STORAGE-SAME: %[[A5:.*5]]: memref) -// CHECK-STORAGE: return %[[A4]] : memref +// CHECK-LABEL: func @sparse_indices_dcsr( +// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: memref) +// CHECK: return %[[A4]] : memref func.func @sparse_indices_dcsr(%arg0: tensor) -> memref { %c = arith.constant 1 : index %0 = sparse_tensor.indices %arg0, %c : tensor to memref return %0 : memref } -// CHECK-CODEGEN-LABEL: func @sparse_values_dcsr( -// CHECK-CODEGEN-SAME: %[[A:.*]]: tuple, memref, memref, 
memref, memref, memref>) -// CHECK-CODEGEN: %[[F:.*]] = sparse_tensor.storage_get %[[A]][5] : tuple, memref, memref, memref, memref, memref> to memref -// CHECK-CODEGEN: return %[[F]] : memref // -// CHECK-STORAGE-LABEL: func @sparse_values_dcsr( -// CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-STORAGE-SAME: %[[A1:.*1]]: memref, -// CHECK-STORAGE-SAME: %[[A2:.*2]]: memref, -// CHECK-STORAGE-SAME: %[[A3:.*3]]: memref, -// CHECK-STORAGE-SAME: %[[A4:.*4]]: memref, -// CHECK-STORAGE-SAME: %[[A5:.*5]]: memref) -// CHECK-STORAGE: return %[[A5]] : memref +// CHECK-LABEL: func @sparse_values_dcsr( +// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: memref) +// CHECK: return %[[A5]] : memref func.func @sparse_values_dcsr(%arg0: tensor) -> memref { %0 = sparse_tensor.values %arg0 : tensor to memref return %0 : memref } -// CHECK-CODEGEN-LABEL: func @sparse_dealloc_csr( -// CHECK-CODEGEN-SAME: %[[A:.*]]: tuple, memref, memref, memref>) -// CHECK-CODEGEN: %[[F0:.*]] = sparse_tensor.storage_get %[[A]][0] : tuple, memref, memref, memref> to memref<2xindex> -// CHECK-CODEGEN: memref.dealloc %[[F0]] : memref<2xindex> -// CHECK-CODEGEN: %[[F1:.*]] = sparse_tensor.storage_get %[[A]][1] : tuple, memref, memref, memref> to memref -// CHECK-CODEGEN: memref.dealloc %[[F1]] : memref -// CHECK-CODEGEN: %[[F2:.*]] = sparse_tensor.storage_get %[[A]][2] : tuple, memref, memref, memref> to memref -// CHECK-CODEGEN: memref.dealloc %[[F2]] : memref -// CHECK-CODEGEN: %[[F3:.*]] = sparse_tensor.storage_get %[[A]][3] : tuple, memref, memref, memref> to memref -// CHECK-CODEGEN: memref.dealloc %[[F3]] : memref -// CHECK-CODEGEN: return // -// CHECK-STORAGE-LABEL: func @sparse_dealloc_csr( -// CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-STORAGE-SAME: %[[A1:.*1]]: memref, -// CHECK-STORAGE-SAME: %[[A2:.*2]]: memref, -// CHECK-STORAGE-SAME: %[[A3:.*3]]: memref) { -// CHECK-STORAGE: memref.dealloc %[[A0]] : memref<2xindex> -// CHECK-STORAGE: memref.dealloc %[[A1]] : memref -// CHECK-STORAGE: memref.dealloc %[[A2]] : memref -// CHECK-STORAGE: memref.dealloc %[[A3]] : memref -// CHECK-STORAGE: return +// CHECK-LABEL: func @sparse_dealloc_csr( +// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: memref) { +// CHECK: memref.dealloc %[[A0]] : memref<2xindex> +// CHECK: memref.dealloc %[[A1]] : memref +// CHECK: memref.dealloc %[[A2]] : memref +// CHECK: memref.dealloc %[[A3]] : memref +// CHECK: return func.func @sparse_dealloc_csr(%arg0: tensor) { bufferization.dealloc_tensor %arg0 : tensor return @@ -264,8 +248,7 @@ // CHECK-LABEL: func @sparse_alloc_csc( // CHECK-SAME: %[[A:.*]]: index) -> -// CHECK-CODEGEN-SAME: tuple, memref, memref, memref> -// CHECK-STORAGE-SAME: memref<2xindex>, memref, memref, memref +// CHECK-SAME: memref<2xindex>, memref, memref, memref // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index @@ -278,9 +261,7 @@ // CHECK: %[[T4:.*]] = memref.cast %[[T3]] : memref<1xindex> to memref // CHECK: %[[T5:.*]] = memref.alloc() : memref<1xf64> // CHECK: %[[T6:.*]] = memref.cast %[[T5]] : memref<1xf64> to memref -// CHECK-CODEGEN: %[[T:.*]] = sparse_tensor.storage(%[[T0]], %[[T2]], %[[T4]], %[[T6]]) -// CHECK-CODEGEN: 
return %[[T]] -// CHECK-STORAGE: return %[[T0]], %[[T2]], %[[T4]], %[[T6]] +// CHECK: return %[[T0]], %[[T2]], %[[T4]], %[[T6]] func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> { %0 = bufferization.alloc_tensor(%arg0) : tensor<10x?xf64, #CSC> %1 = sparse_tensor.load %0 : tensor<10x?xf64, #CSC> @@ -288,8 +269,7 @@ } // CHECK-LABEL: func @sparse_alloc_3d() -> -// CHECK-CODEGEN-SAME: tuple, memref> -// CHECK-STORAGE-SAME: memref<3xindex>, memref +// CHECK-SAME: memref<3xindex>, memref // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index @@ -302,9 +282,7 @@ // CHECK: memref.store %[[C20]], %[[A0]][%[[C2]]] : memref<3xindex> // CHECK: %[[A:.*]] = memref.alloc() : memref<6000xf64> // CHECK: %[[A1:.*]] = memref.cast %[[A]] : memref<6000xf64> to memref -// CHECK-CODEGEN: %[[T:.*]] = sparse_tensor.storage(%[[A0]], %[[A1]]) -// CHECK-CODEGEN: return %[[T]] : tuple, memref> -// CHECK-STORAGE: return %[[A0]], %[[A1]] : memref<3xindex>, memref +// CHECK: return %[[A0]], %[[A1]] : memref<3xindex>, memref func.func @sparse_alloc_3d() -> tensor<10x20x30xf64, #Dense3D> { %0 = bufferization.alloc_tensor() : tensor<10x20x30xf64, #Dense3D> %1 = sparse_tensor.load %0 : tensor<10x20x30xf64, #Dense3D> diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir --- a/mlir/test/Dialect/SparseTensor/invalid.mlir +++ b/mlir/test/Dialect/SparseTensor/invalid.mlir @@ -442,63 +442,3 @@ tensor<4x4xf64, #DC> to tensor<9x4xf64, #DC> return %0 : tensor<9x4xf64, #DC> } - -// ----- - -func.func @sparse_storage_new(%arg0: memref, %arg1: memref, %arg2: f64) -> - tuple, memref> { - // expected-error@+1{{The number of inputs is inconsistent with output}} - %0 = sparse_tensor.storage(%arg0, %arg1, %arg2) - : memref, memref, f64 to tuple, memref> - return %0 : tuple, memref> -} - -// ----- - -func.func @sparse_storage_new(%arg0: memref, %arg1: memref, %arg2: f64) -> - tuple, memref, f64> { - // expected-error@+1{{Type mismatch between}} - %0 = sparse_tensor.storage(%arg0, %arg1, %arg2) - : memref, memref, f64 to tuple, memref, f64> - return %0 : tuple, memref, f64> -} - -// ----- - -func.func @sparse_storage_get(%arg0: tuple, memref, f64>) -> memref { - // expected-error@+1{{Out-of-bound access}} - %0 = sparse_tensor.storage_get %arg0[3] - : tuple, memref, f64> to - memref - return %0 : memref -} - -// ----- - -func.func @sparse_storage_get(%arg0: tuple, memref, f64>) -> memref { - // expected-error@+1{{Type mismatch}} - %0 = sparse_tensor.storage_get %arg0[2] - : tuple, memref, f64> to - memref - return %0 : memref -} - -// ----- - -func.func @sparse_storage_set(%arg0: tuple, memref, f64>, %arg1: memref) -> tuple, memref, f64> { - // expected-error@+1{{Out-of-bound access}} - %0 = sparse_tensor.storage_set %arg0[3], %arg1 - : tuple, memref, f64>, memref to - tuple, memref, f64> - return %0 : tuple, memref, f64> -} - -// ----- - -func.func @sparse_storage_set(%arg0: tuple, memref, f64>, %arg1: memref) -> tuple, memref, f64> { - // expected-error@+1{{Type mismatch}} - %0 = sparse_tensor.storage_set %arg0[2], %arg1 - : tuple, memref, f64>, memref to - tuple, memref, f64> - return %0 : tuple, memref, f64> -} diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir --- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir +++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir @@ -314,50 +314,3 @@ tensor<4x4xf64, #SparseMatrix> to 
tensor<9x4xf64, #SparseMatrix> return %0 : tensor<9x4xf64, #SparseMatrix> } - -// ----- - - -// CHECK: func @sparse_storage_new( -// CHECK-SAME: %[[A0:.*0]]: memref, -// CHECK-SAME: %[[A1:.*1]]: memref, -// CHECK-SAME: %[[A2:.*]]: f64 -// CHECK: %[[TMP_0:.*]] = sparse_tensor.storage(%[[A0]], %[[A1]], %[[A2]]) -// CHECK: return %[[TMP_0]] : tuple, memref, f64> -func.func @sparse_storage_new(%arg0: memref, %arg1: memref, %arg2: f64) -> - tuple, memref, f64> { - %0 = sparse_tensor.storage(%arg0, %arg1, %arg2) - : memref, memref, f64 to tuple, memref, f64> - return %0 : tuple, memref, f64> -} - -// ----- - -// CHECK-LABEL: func @sparse_storage_get( -// CHECK-SAME: %[[A0:.*]]: tuple, memref, f64> -// CHECK: %[[TMP0:.*]] = sparse_tensor.storage_get %[[A0]][0] : -// CHECK-SAME: tuple, memref, f64> -// CHECK-SAME: to memref -// CHECK: return %[[TMP0]] : memref -func.func @sparse_storage_get(%arg0: tuple, memref, f64>) -> memref { - %0 = sparse_tensor.storage_get %arg0[0] - : tuple, memref, f64> to memref - return %0 : memref -} - -// ----- - -// CHECK-LABEL: func @sparse_storage_set( -// CHECK-SAME: %[[A0:.*]]: tuple, memref, f64>, -// CHECK-SAME: %[[A1:.*]]: memref -// CHECK: %[[TMP0:.*]] = sparse_tensor.storage_set %[[A0]][0], %[[A1]] : -// CHECK-SAME: tuple, memref, f64>, -// CHECK-SAME: memref -// CHECK-SAME: to tuple, memref, f64> -// CHECK: return %0 : tuple, memref, f64> -func.func @sparse_storage_set(%arg0: tuple, memref, f64>, %arg1: memref) -> tuple, memref, f64> { - %0 = sparse_tensor.storage_set %arg0[0], %arg1 - : tuple, memref, f64>, memref to - tuple, memref, f64> - return %0 : tuple, memref, f64> -} diff --git a/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir b/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir deleted file mode 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir +++ /dev/null @@ -1,60 +0,0 @@ -// RUN: mlir-opt %s -sparse-tensor-storage-expansion -cse | FileCheck %s - -// CHECK-LABEL: func @sparse_storage_expand( -// CHECK-SAME: %[[TMP_arg0:.*0]]: memref, -// CHECK-SAME: %[[TMP_arg1:.*1]]: memref, -// CHECK-SAME: %[[TMP_arg2:.*]]: f64 -// CHECK return %[[TMP_arg0]], %[[TMP_arg1]], %[[TMP_arg2]] -func.func @sparse_storage_expand(%arg0: tuple, memref, f64>) - -> tuple, memref, f64> { - return %arg0 : tuple, memref, f64> -} - -// CHECK-LABEL: func @call_sparse_storage_expand( -// CHECK-SAME: %[[TMP_arg0:.*0]]: memref, -// CHECK-SAME: %[[TMP_arg1:.*1]]: memref, -// CHECK-SAME: %[[TMP_arg2:.*]]: f64) -// CHECK: %[[TMP_0:.*]]:3 = call @sparse_storage_expand(%[[TMP_arg0]], %[[TMP_arg1]], %[[TMP_arg2]]) -// CHECK: return %[[TMP_0]]#0, %[[TMP_0]]#1, %[[TMP_0]]#2 : memref, memref, f64 -func.func @call_sparse_storage_expand(%arg0: tuple, memref, f64>) - -> tuple, memref, f64> { - %1 = call @sparse_storage_expand(%arg0) : (tuple, memref, f64>) -> - tuple, memref, f64> - return %1 : tuple, memref, f64> -} - -// CHECK-LABEL: func @sparse_storage( -// CHECK-SAME: %[[TMP_arg0:.*0]]: memref, -// CHECK-SAME: %[[TMP_arg1:.*1]]: memref, -// CHECK-SAME: %[[TMP_arg2:.*2]]: memref) -// CHECK: return %[[TMP_arg0]], %[[TMP_arg1]], %[[TMP_arg2]] -func.func @sparse_storage(%arg0: memref, %arg1: memref, %arg2: memref) - -> tuple, memref, memref> { - %1 = sparse_tensor.storage(%arg0, %arg1, %arg2) : memref, memref, memref to tuple, memref, memref> - return %1 : tuple, memref, memref> -} - -// CHECK-LABEL: func @sparse_storage_get( -// CHECK-SAME: %[[TMP_arg0:.*0]]: memref, -// CHECK-SAME: %[[TMP_arg1:.*1]]: memref, -// CHECK-SAME: %[[TMP_arg2:.*]]: f64) -// 
CHECK: return %[[TMP_arg0]] : memref -func.func @sparse_storage_get(%arg0: tuple, memref, f64>) -> memref { - %0 = sparse_tensor.storage_get %arg0[0] - : tuple, memref, f64> to memref - return %0 : memref -} - -// CHECK-LABEL: func @sparse_storage_set( -// CHECK-SAME: %[[TMP_arg0:.*0]]: memref, -// CHECK-SAME: %[[TMP_arg1:.*1]]: memref, -// CHECK-SAME: %[[TMP_arg2:.*]]: f64, -// CHECK-SAME: %[[TMP_arg3:.*]]: memref) -// CHECK: return %[[TMP_arg3]], %[[TMP_arg1]], %[[TMP_arg2]] : memref, memref, f64 -func.func @sparse_storage_set(%arg0: tuple, memref, f64>, - %arg1: memref) -> tuple, memref, f64> { - %0 = sparse_tensor.storage_set %arg0[0], %arg1 - : tuple, memref, f64>, memref to - tuple, memref, f64> - return %0 : tuple, memref, f64> -}
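
For readers skimming the patch, the net effect of folding the storage expansion into `--sparse-tensor-codegen` can be illustrated on a trivial function. This is a minimal sketch, not taken from the tests above: `@sparse_id` is a hypothetical name, and the flattened field types assume a 1-D "compressed" encoding `#SparseVector` with default pointer/index bit widths, which maps to dimSizes, pointers, indices, and values buffers.

```mlir
// Assumed input: a function that simply forwards a sparse vector.
func.func @sparse_id(%arg0: tensor<?xf64, #SparseVector>)
    -> tensor<?xf64, #SparseVector> {
  return %arg0 : tensor<?xf64, #SparseVector>
}

// After --sparse-tensor-codegen (sketch): the sparse tensor argument and
// result are replaced directly by the storage fields; no tuple type and no
// separate --sparse-tensor-storage-expansion run is involved anymore.
func.func @sparse_id(%arg0: memref<1xindex>, %arg1: memref<?xindex>,
                     %arg2: memref<?xindex>, %arg3: memref<?xf64>)
    -> (memref<1xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
  return %arg0, %arg1, %arg2, %arg3
      : memref<1xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
}
```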