diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -217,10 +217,12 @@
     target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
       return converter.isLegal(op.getOperandTypes());
     });
+    // We generate UnrealizedConversionCastOp to bridge between tuples and
+    // their flattened list of types.
+    target.addLegalOp<UnrealizedConversionCastOp>();
     // Populate with rules and apply rewriting rules.
     populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                    converter);
-    populateCallOpTypeConversionPattern(patterns, converter);
     scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                          target);
     populateSparseTensorStorageExpansionPatterns(converter, patterns);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
@@ -41,10 +41,69 @@
   return llvm::None;
 }
 
+/// Flatten a list of operands that may contain tuples.
+static void flattenOperands(ValueRange operands,
+                            SmallVectorImpl<Value> &flattened) {
+  // In case of
+  //  tuple<a, b>, c, tuple<d, e>
+  // ==>
+  //  a, b, c, d, e
+  for (auto operand : operands) {
+    if (auto cast =
+            dyn_cast<UnrealizedConversionCastOp>(operand.getDefiningOp());
+        cast && cast->getResultTypes()[0].isa<TupleType>())
+      // An unrealized_conversion_cast will be inserted by the type converter
+      // to bridge the gap of the 1:N conversion between a tuple and its
+      // flattened types. In this case, take the operands of the cast and
+      // replace the tuple output with the flattened type array.
+      flattened.append(cast.getOperands().begin(), cast.getOperands().end());
+    else
+      flattened.push_back(operand);
+  }
+}
 //===----------------------------------------------------------------------===//
 // Conversion rules.
 //===----------------------------------------------------------------------===//
 
+/// Sparse tensor storage conversion rule for sparse_tensor::storage_get.
+class SparseStorageGetConverter : public OpConversionPattern<StorageGetOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(StorageGetOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto castOp =
+        cast<UnrealizedConversionCastOp>(adaptor.getStorage().getDefiningOp());
+    uint64_t idx = op.getIdx().getZExtValue();
+    assert(idx < castOp.getOperands().size());
+
+    rewriter.replaceOp(op, castOp.getOperand(idx));
+    return success();
+  }
+};
+
+/// Sparse tensor storage conversion rule for sparse_tensor::storage_set.
+class SparseStorageSetConverter : public OpConversionPattern<StorageSetOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(StorageSetOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto castOp =
+        cast<UnrealizedConversionCastOp>(adaptor.getStorage().getDefiningOp());
+    uint64_t idx = op.getIdx().getZExtValue();
+
+    SmallVector<Value> values(castOp.getOperands());
+    assert(idx < values.size());
+
+    // Update the corresponding element.
+    values[idx] = adaptor.getValue();
+    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
+        op, TypeRange{op.getType()}, values);
+    return success();
+  }
+};
+
 /// Sparse tensor storage conversion rule for returns.
 class SparseStorageReturnConverter
     : public OpConversionPattern<func::ReturnOp> {
@@ -54,24 +113,69 @@
   matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     SmallVector<Value> flattened;
-    for (auto operand : adaptor.getOperands()) {
-      if (auto cast =
-              dyn_cast<UnrealizedConversionCastOp>(operand.getDefiningOp());
-          cast && cast->getResultTypes()[0].isa<TupleType>())
-        // An unrealized_conversion_cast will be inserted by type converter to
-        // inter-mix the gap between 1:N conversion between tuple and types.
-        // In this case, take the operands in the cast and replace the tuple
-        // output with the flattened type array.
-        flattened.append(cast.getOperands().begin(), cast.getOperands().end());
-      else
-        flattened.push_back(operand);
-    }
+    flattenOperands(adaptor.getOperands(), flattened);
     // Create a return with the flattened value extracted from tuple.
     rewriter.replaceOpWithNewOp<func::ReturnOp>(op, flattened);
     return success();
   }
 };
 
+/// Sparse tensor storage conversion rule for calls.
+class SparseStorageCallConverter : public OpConversionPattern<func::CallOp> {
+public:
+  // The default CallOp converter cannot handle 1:N type conversion properly.
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // In case of:
+    //  tuple(a, b), f, tuple(c, d) = call @foo(...)
+    // ==>
+    //  a, b, f, c, d = call @foo(...)
+    //  cast(a, b)->tuple, f, cast(c,d)->tuple
+    SmallVector<Type> finalRetTy;
+    if (failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy)))
+      return failure();
+
+    // (1) Generate a new call with flattened return values.
+    SmallVector<Value> flattened;
+    flattenOperands(adaptor.getOperands(), flattened);
+    auto newCall = rewriter.create<func::CallOp>(loc, op.getCallee(),
+                                                 finalRetTy, flattened);
+
+    // (2) Create cast operations for tuple returns.
+    SmallVector<Value> castedRet;
+    // Tracks the offset of the current return value (of the original call)
+    // relative to the new call (after tuple flattening).
+    unsigned retOffset = 0;
+    for (auto ret : op.getResults()) {
+      assert(retOffset < newCall.getNumResults());
+      auto tupleRet = ret.getType().dyn_cast<TupleType>();
+      if (tupleRet) {
+        auto tupleSize = tupleRet.size();
+        // NOTE: The range is computed under the assumption of non-recursive
+        // tuple types.
+        ValueRange tupleElem(iterator_range<ResultRange::iterator>(
+            newCall.result_begin() + retOffset,
+            newCall.result_begin() + retOffset + tupleSize));
+        auto castOp = rewriter.create<UnrealizedConversionCastOp>(
+            loc, TypeRange({tupleRet}), tupleElem);
+        castedRet.push_back(castOp.getResult(0));
+        retOffset += tupleSize;
+      } else {
+        // If this is not a tuple, simply add it to the returned values.
+        castedRet.push_back(ret);
+        retOffset++;
+      }
+    }
+
+    assert(castedRet.size() == op.getNumResults());
+    rewriter.replaceOp(op, castedRet);
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -91,6 +195,7 @@
 /// to expand compounded sparse tensor tuples.
 void mlir::populateSparseTensorStorageExpansionPatterns(
     TypeConverter &typeConverter, RewritePatternSet &patterns) {
-  patterns.add<SparseStorageReturnConverter>(typeConverter,
-                                             patterns.getContext());
+  patterns.add<SparseStorageGetConverter, SparseStorageSetConverter,
+               SparseStorageReturnConverter, SparseStorageCallConverter>(
+      typeConverter, patterns.getContext());
 }
diff --git a/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir b/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparse-tensor-storage-expansion | FileCheck %s
+// RUN: mlir-opt %s -sparse-tensor-storage-expansion -cse | FileCheck %s
 
 // CHECK-LABEL: func @sparse_storage_expand(
 // CHECK-SAME: %[[TMP_arg0:.*0]]: memref,
 // CHECK-SAME: %[[TMP_arg1:.*1]]: memref,
 // CHECK-SAME: %[[TMP_arg2:.*]]: f64)
@@ -9,3 +9,41 @@
                                    -> tuple<memref, memref, f64> {
   return %arg0 : tuple<memref, memref, f64>
 }
+
+// CHECK-LABEL: func @call_sparse_storage_expand(
+// CHECK-SAME: %[[TMP_arg0:.*0]]: memref,
+// CHECK-SAME: %[[TMP_arg1:.*1]]: memref,
+// CHECK-SAME: %[[TMP_arg2:.*]]: f64)
+// CHECK: %[[TMP_0:.*]]:3 = call @sparse_storage_expand(%[[TMP_arg0]], %[[TMP_arg1]], %[[TMP_arg2]])
+// CHECK: return %[[TMP_0]]#0, %[[TMP_0]]#1, %[[TMP_0]]#2 : memref, memref, f64
+func.func @call_sparse_storage_expand(%arg0: tuple<memref, memref, f64>)
+                                        -> tuple<memref, memref, f64> {
+  %1 = call @sparse_storage_expand(%arg0) : (tuple<memref, memref, f64>) ->
+                                             tuple<memref, memref, f64>
+  return %1 : tuple<memref, memref, f64>
+}
+
+// CHECK-LABEL: func @sparse_storage_get(
+// CHECK-SAME: %[[TMP_arg0:.*0]]: memref,
+// CHECK-SAME: %[[TMP_arg1:.*1]]: memref,
+// CHECK-SAME: %[[TMP_arg2:.*]]: f64)
+// CHECK: return %[[TMP_arg0]] : memref
+func.func @sparse_storage_get(%arg0: tuple<memref, memref, f64>) -> memref {
+  %0 = sparse_tensor.storage_get %arg0[0]
+     : tuple<memref, memref, f64> to memref
+  return %0 : memref
+}
+
+// CHECK-LABEL: func @sparse_storage_set(
+// CHECK-SAME: %[[TMP_arg0:.*0]]: memref,
+// CHECK-SAME: %[[TMP_arg1:.*1]]: memref,
+// CHECK-SAME: %[[TMP_arg2:.*]]: f64,
+// CHECK-SAME: %[[TMP_arg3:.*]]: memref)
+// CHECK: return %[[TMP_arg3]], %[[TMP_arg1]], %[[TMP_arg2]] : memref, memref, f64
+func.func @sparse_storage_set(%arg0: tuple<memref, memref, f64>,
+                              %arg1: memref) -> tuple<memref, memref, f64> {
+  %0 = sparse_tensor.storage_set %arg0[0], %arg1
+     : tuple<memref, memref, f64>, memref to
+       tuple<memref, memref, f64>
+  return %0 : tuple<memref, memref, f64>
+}