diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -31,6 +31,11 @@
 
 namespace {
 
+using FuncGeneratorType =
+    function_ref<void(OpBuilder &, ModuleOp, func::FuncOp, RankedTensorType)>;
+
+static constexpr const char kInsertFuncNamePrefix[] = "_insert_";
+
 static constexpr uint64_t dimSizesIdx = 0;
 static constexpr uint64_t memSizesIdx = 1;
 static constexpr uint64_t fieldsIdx = 2;
@@ -476,12 +481,23 @@
 ///
 /// TODO: better unord/not-unique; also generalize, optimize, specialize!
 ///
-static void genInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
-                      SmallVectorImpl<Value> &fields,
-                      SmallVectorImpl<Value> &indices, Value value) {
+static void genInsertBody(OpBuilder &builder, ModuleOp module,
+                          func::FuncOp func, RankedTensorType rtp) {
+  OpBuilder::InsertionGuard insertionGuard(builder);
+  Block *entryBlock = func.addEntryBlock();
+  builder.setInsertionPointToStart(entryBlock);
+
+  Location loc = func.getLoc();
+  ValueRange args = entryBlock->getArguments();
   unsigned rank = rtp.getShape().size();
-  assert(rank == indices.size());
-  unsigned field = fieldsIdx; // start past header
+
+  ValueRange tmp = args.drop_back(rank + 1);
+  SmallVector<Value> fields(tmp.begin(), tmp.end());
+  tmp = args.take_back(rank + 1).drop_back();
+  SmallVector<Value> indices(tmp.begin(), tmp.end());
+  Value value = args.back();
+
+  unsigned field = fieldsIdx; // Start past header.
   Value pos = constantZero(builder, loc, builder.getIndexType());
   // Generate code for every dimension.
   for (unsigned d = 0; d < rank; d++) {
@@ -519,6 +535,76 @@
   else
     genStore(builder, loc, value, fields[field++], pos);
   assert(fields.size() == field);
+  builder.create<func::ReturnOp>(loc, fields);
+}
+
+/// Generates a call to a function to perform an insertion operation. If the
+/// function doesn't exist yet, call `createFunc` to generate the function.
+static void genInsertionCallHelper(OpBuilder &builder, RankedTensorType rtp,
+                                   SmallVectorImpl<Value> &fields,
+                                   SmallVectorImpl<Value> &indices, Value value,
+                                   func::FuncOp insertPoint,
+                                   StringRef namePrefix,
+                                   FuncGeneratorType createFunc) {
+  // The mangled name of the function has this format:
+  //   <namePrefix>_[C|S|D]_<shape>_<ordering>_<eltType>
+  //     _<indexBitWidth>_<pointerBitWidth>
+  SmallString<32> nameBuffer;
+  llvm::raw_svector_ostream nameOstream(nameBuffer);
+  nameOstream << namePrefix;
+  unsigned rank = rtp.getShape().size();
+  assert(rank == indices.size());
+  for (unsigned d = 0; d < rank; d++) {
+    if (isCompressedDim(rtp, d)) {
+      nameOstream << "C_";
+    } else if (isSingletonDim(rtp, d)) {
+      nameOstream << "S_";
+    } else {
+      nameOstream << "D_";
+    }
+  }
+  // Static dim sizes are used in the generated code while dynamic sizes are
+  // loaded from the dimSizes buffer. This is the reason for adding the shape
+  // to the function name.
+  for (auto d : rtp.getShape())
+    nameOstream << d << "_";
+  SparseTensorEncodingAttr enc = getSparseTensorEncoding(rtp);
+  // Permutation information is also used in generating insertion.
+  if (enc.getDimOrdering() && !enc.getDimOrdering().isIdentity())
+    nameOstream << enc.getDimOrdering() << "_";
+  nameOstream << rtp.getElementType() << "_";
+  nameOstream << enc.getIndexBitWidth() << "_" << enc.getPointerBitWidth();
+
+  // Look up the function.
+  ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
+  MLIRContext *context = module.getContext();
+  auto result = SymbolRefAttr::get(context, nameOstream.str());
+  auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
+
+  SmallVector<Value> operands(fields.begin(), fields.end());
+  operands.append(indices.begin(), indices.end());
+  operands.push_back(value);
+  Location loc = insertPoint.getLoc();
+
+  if (!func) {
+    // Create the function.
+    OpBuilder::InsertionGuard insertionGuard(builder);
+    builder.setInsertionPoint(insertPoint);
+
+    func = builder.create<func::FuncOp>(
+        loc, nameOstream.str(),
+        FunctionType::get(context, ValueRange(operands).getTypes(),
+                          ValueRange(fields).getTypes()));
+    func.setPrivate();
+    createFunc(builder, module, func, rtp);
+  }
+
+  // Generate a call to perform the insertion and update `fields` with values
+  // returned from the call.
+  func::CallOp call = builder.create<func::CallOp>(loc, func, operands);
+  for (size_t i = 0; i < fields.size(); i++) {
+    fields[i] = call.getResult(i);
+  }
 }
 
 /// Generations insertion finalization code.
@@ -865,7 +951,9 @@
       Value value = genLoad(rewriter, loc, values, index);
       indices.push_back(index);
       // TODO: faster for subsequent insertions?
-      genInsert(rewriter, loc, dstType, fields, indices, value);
+      auto insertPoint = op->template getParentOfType<func::FuncOp>();
+      genInsertionCallHelper(rewriter, dstType, fields, indices, value,
+                             insertPoint, kInsertFuncNamePrefix, genInsertBody);
       genStore(rewriter, loc, constantZero(rewriter, loc, eltType), values,
                index);
       genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, index);
@@ -899,7 +987,10 @@
     SmallVector<Value> indices(adaptor.getIndices());
     // Generate insertion.
     Value value = adaptor.getValue();
-    genInsert(rewriter, op->getLoc(), dstType, fields, indices, value);
+    auto insertPoint = op->template getParentOfType<func::FuncOp>();
+    genInsertionCallHelper(rewriter, dstType, fields, indices, value,
+                           insertPoint, kInsertFuncNamePrefix, genInsertBody);
+
     // Replace operation with resulting memrefs.
     rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), dstType, fields));
     return success();
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -381,6 +381,17 @@
   return %added : memref<?xindex>
 }
 
+// CHECK-LABEL: func.func private @_insert_C_100_f64_0_0(
+// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[A5:.*5]]: index,
+// CHECK-SAME: %[[A6:.*6]]: f64)
+// CHECK: %[[PV:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
+// CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[PV]]
+//
 // CHECK-LABEL: func @sparse_compression_1d(
 // CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
 // CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
@@ -396,18 +407,19 @@
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: sparse_tensor.sort %[[A8]], %[[A7]] : memref<?xindex>
-// CHECK: %[[R:.*]]:2 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] iter_args(%[[P0:.*]] = %[[A3]], %[[P1:.*]] = %[[A4]]) -> (memref<?xindex>, memref<?xf64>) {
+// CHECK: %[[R:.*]]:5 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]]
+// CHECK-SAME: iter_args(%[[P0:.*]] = %[[A0]], %[[P1:.*]] = %[[A1]], %[[P2:.*]] = %[[A2]], %[[P3:.*]] = %[[A3]], %[[P4:.*]] = %[[A4]]) -> (memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
 // CHECK: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref<?xindex>
 // CHECK: %[[VAL:.*]] = memref.load %[[A5]][%[[INDEX]]] : memref<?xf64>
-// CHECK: %[[PV:.*]] = sparse_tensor.push_back %[[A1]], %[[P1]], %[[VAL]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
+// CHECK: %[[C:.*]]:5 = func.call @_insert_C_100_f64_0_0(%[[P0]], %[[P1]], %[[P2]], %[[P3]], %[[P4]], %[[INDEX]], %[[VAL]])
 // CHECK: memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref<?xf64>
 // CHECK: memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref<?xi1>
-// CHECK: scf.yield %{{.*}}, %[[PV]] : memref<?xindex>, memref<?xf64>
+// CHECK: scf.yield %[[C]]#0, %[[C]]#1, %[[C]]#2, %[[C]]#3, %[[C]]#4 : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
 // CHECK: }
 // CHECK: memref.dealloc %[[A5]] : memref<?xf64>
 // CHECK: memref.dealloc %[[A6]] : memref<?xi1>
 // CHECK: memref.dealloc %[[A7]] : memref<?xindex>
-// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[R]]#0, %[[R]]#1
+// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4
 // CHECK-SAME: memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
 func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>,
                                  %values: memref<?xf64>,
@@ -420,6 +432,18 @@
   return %1 : tensor<100xf64, #SV>
 }
 
+// CHECK-LABEL: func.func private @_insert_D_C_8_8_f64_64_32(
+// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[A5:.*5]]: index,
+// CHECK-SAME: %[[A6:.*6]]: index,
+// CHECK-SAME: %[[A7:.*7]]: f64)
+// CHECK: %[[PV:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A7]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
+// CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[PV]]
+//
 // CHECK-LABEL: func @sparse_compression(
 // CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
 // CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
@@ -436,18 +460,19 @@
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: sparse_tensor.sort %[[A8]], %[[A7]] : memref<?xindex>
-// CHECK: %[[R:.*]]:2 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] iter_args(%[[P0:.*]] = %[[A3]], %[[P1:.*]] = %[[A4]]) -> (memref<?xi64>, memref<?xf64>) {
+// CHECK: %[[R:.*]]:5 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]]
+// CHECK-SAME: iter_args(%[[P0:.*]] = %[[A0]], %[[P1:.*]] = %[[A1]], %[[P2:.*]] = %[[A2]], %[[P3:.*]] = %[[A3]], %[[P4:.*]] = %[[A4]]) -> (memref<2xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>) {
 // CHECK: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref<?xindex>
 // CHECK: %[[VAL:.*]] = memref.load %[[A5]][%[[INDEX]]] : memref<?xf64>
-// CHECK: %[[PV:.*]] = sparse_tensor.push_back %[[A1]], %[[P1]], %[[VAL]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
+// CHECK: %[[C:.*]]:5 = func.call @_insert_D_C_8_8_f64_64_32(%[[P0]], %[[P1]], %[[P2]], %[[P3]], %[[P4]], %[[A9]], %[[INDEX]], %[[VAL]])
 // CHECK: memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref<?xf64>
 // CHECK: memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref<?xi1>
-// CHECK: scf.yield %{{.*}}, %[[PV]] : memref<?xi64>, memref<?xf64>
+// CHECK: scf.yield %[[C]]#0, %[[C]]#1, %[[C]]#2, %[[C]]#3, %[[C]]#4 : memref<2xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
 // CHECK: }
 // CHECK: memref.dealloc %[[A5]] : memref<?xf64>
 // CHECK: memref.dealloc %[[A6]] : memref<?xi1>
 // CHECK: memref.dealloc %[[A7]] : memref<?xindex>
-// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[R]]#0, %[[R]]#1
+// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4
 // CHECK-SAME: memref<2xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
 func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
                               %values: memref<?xf64>,
@@ -461,6 +486,18 @@
   return %1 : tensor<8x8xf64, #CSR>
 }
 
+// CHECK-LABEL: func.func private @_insert_D_C_8_8_f64_0_0(
+// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[A5:.*5]]: index,
+// CHECK-SAME: %[[A6:.*6]]: index,
+// CHECK-SAME: %[[A7:.*7]]: f64)
+// CHECK: %[[PV:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A7]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
+// CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[PV]]
+//
 // CHECK-LABEL: func @sparse_compression_unordered(
 // CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
 // CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
@@ -477,18 +514,19 @@
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 // CHECK-NOT: sparse_tensor.sort
-// CHECK: %[[R:.*]]:2 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] iter_args(%[[P0:.*]] = %[[A3]], %[[P1:.*]] = %[[A4]]) -> (memref<?xindex>, memref<?xf64>) {
+// CHECK: %[[R:.*]]:5 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]]
+// CHECK-SAME: iter_args(%[[P0:.*]] = %[[A0]], %[[P1:.*]] = %[[A1]], %[[P2:.*]] = %[[A2]], %[[P3:.*]] = %[[A3]], %[[P4:.*]] = %[[A4]]) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
 // CHECK: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref<?xindex>
 // CHECK: %[[VAL:.*]] = memref.load %[[A5]][%[[INDEX]]] : memref<?xf64>
-// CHECK: %[[PV:.*]] = sparse_tensor.push_back %[[A1]], %[[P1]], %[[VAL]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
+// CHECK: %[[C:.*]]:5 = func.call @_insert_D_C_8_8_f64_0_0(%[[P0]], %[[P1]], %[[P2]], %[[P3]], %[[P4]], %[[A9]], %[[INDEX]], %[[VAL]])
 // CHECK: memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref<?xf64>
 // CHECK: memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref<?xi1>
-// CHECK: scf.yield %{{.*}}, %[[PV]] : memref<?xindex>, memref<?xf64>
+// CHECK: scf.yield %[[C]]#0, %[[C]]#1, %[[C]]#2, %[[C]]#3, %[[C]]#4 : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
 // CHECK: }
 // CHECK: memref.dealloc %[[A5]] : memref<?xf64>
 // CHECK: memref.dealloc %[[A6]] : memref<?xi1>
 // CHECK: memref.dealloc %[[A7]] : memref<?xindex>
-// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[R]]#0, %[[R]]#1
+// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4
 // CHECK-SAME: memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
 func.func @sparse_compression_unordered(%tensor: tensor<8x8xf64, #UCSR>,
                                         %values: memref<?xf64>,
@@ -502,7 +540,7 @@
   return %1 : tensor<8x8xf64, #UCSR>
 }
 
-// CHECK-LABEL: func @sparse_insert(
+// CHECK-LABEL: func.func private @_insert_C_128_f64_0_0(
 // CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
 // CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
 // CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
@@ -512,6 +550,16 @@
 // CHECK-SAME: %[[A6:.*6]]: f64)
 // CHECK: %[[P:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
 // CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[P]] :
+// CHECK: func @sparse_insert(
+// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[A5:.*5]]: index,
+// CHECK-SAME: %[[A6:.*6]]: f64)
+// CHECK: %[[R:.*]]:5 = call @_insert_C_128_f64_0_0(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]])
+// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4
 // CHECK-SAME: memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
 func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SV> {
   %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SV>
@@ -519,7 +567,7 @@
   return %1 : tensor<128xf64, #SV>
 }
 
-// CHECK-LABEL: func @sparse_insert_typed(
+// CHECK-LABEL: func.func private @_insert_C_128_f64_64_32(
 // CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
 // CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
 // CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
@@ -529,6 +577,16 @@
 // CHECK-SAME: %[[A6:.*6]]: f64)
 // CHECK: %[[P:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
 // CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[P]] :
+// CHECK: func @sparse_insert_typed(
+// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[A5:.*5]]: index,
+// CHECK-SAME: %[[A6:.*6]]: f64)
+// CHECK: %[[R:.*]]:5 = call @_insert_C_128_f64_64_32(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]])
+// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4
 // CHECK-SAME: memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
 func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SparseVector> {
   %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector>
@@ -547,4 +605,4 @@
 func.func @sparse_nop_convert(%arg0: tensor<32xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
   %0 = sparse_tensor.convert %arg0 : tensor<32xf32, #SparseVector> to tensor<?xf32, #SparseVector>
   return %0 : tensor<?xf32, #SparseVector>
-}
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
@@ -12,32 +12,68 @@
 //
 // Computes C = A x B with all matrices sparse (SpMSpM) in CSR.
 //
+// CHECK-LABEL: func.func private @_insert_D_C_4_4_f64_0_0(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<2xindex>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<3xindex>,
+// CHECK-SAME: %[[VAL_2:[^ ]+]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_3:.*]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_4:.*]]: memref<?xf64>,
+// CHECK-SAME: %[[VAL_5:[^ ]+]]: index,
+// CHECK-SAME: %[[VAL_6:.*]]: index,
+// CHECK-SAME: %[[VAL_7:.*]]: f64) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
+// CHECK: %[[VAL_8:.*]] = arith.constant false
+// CHECK: %[[VAL_9:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_5]], %[[VAL_9]] : index
+// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_10]]] : memref<?xindex>
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_9]]] : memref<3xindex>
+// CHECK: %[[VAL_14:.*]] = arith.subi %[[VAL_12]], %[[VAL_9]] : index
+// CHECK: %[[VAL_15:.*]] = arith.cmpi ult, %[[VAL_11]], %[[VAL_12]] : index
+// CHECK: %[[VAL_16:.*]] = scf.if %[[VAL_15]] -> (i1) {
+// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_14]]] : memref<?xindex>
+// CHECK: %[[VAL_18:.*]] = arith.cmpi eq, %[[VAL_17]], %[[VAL_6]] : index
+// CHECK: scf.yield %[[VAL_18]] : i1
+// CHECK: } else {
+// CHECK: memref.store %[[VAL_13]], %[[VAL_2]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK: scf.yield %[[VAL_8]] : i1
+// CHECK: }
+// CHECK: %[[VAL_19:.*]] = scf.if %[[VAL_20:.*]] -> (memref<?xindex>) {
+// CHECK: scf.yield %[[VAL_3]] : memref<?xindex>
+// CHECK: } else {
+// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_13]], %[[VAL_9]] : index
+// CHECK: memref.store %[[VAL_21]], %[[VAL_2]]{{\[}}%[[VAL_10]]] : memref<?xindex>
+// CHECK: %[[VAL_22:.*]] = sparse_tensor.push_back %[[VAL_1]], %[[VAL_3]], %[[VAL_6]] {idx = 1 : index} : memref<3xindex>, memref<?xindex>, index
+// CHECK: scf.yield %[[VAL_22]] : memref<?xindex>
+// CHECK: }
+// CHECK: %[[VAL_23:.*]] = sparse_tensor.push_back %[[VAL_1]], %[[VAL_4]], %[[VAL_7]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
+// CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_24:.*]], %[[VAL_23]] : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+// CHECK: }
+
 // CHECK-LABEL: func.func @matmul(
-// CHECK-SAME: %[[VAL_0:.*0]]: memref<2xindex>,
-// CHECK-SAME: %[[VAL_1:.*1]]: memref<3xindex>,
-// CHECK-SAME: %[[VAL_2:.*2]]: memref<?xindex>,
-// CHECK-SAME: %[[VAL_3:.*3]]: memref<?xindex>,
-// CHECK-SAME: %[[VAL_4:.*4]]: memref<?xf64>,
-// CHECK-SAME: %[[VAL_5:.*5]]: memref<2xindex>,
-// CHECK-SAME: %[[VAL_6:.*6]]: memref<3xindex>,
-// CHECK-SAME: %[[VAL_7:.*7]]: memref<?xindex>,
-// CHECK-SAME: %[[VAL_8:.*8]]: memref<?xindex>,
-// CHECK-SAME: %[[VAL_9:.*9]]: memref<?xf64>)
-// CHECK-SAME: -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
-// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_14:.*]] = arith.constant false
-// CHECK-DAG: %[[VAL_15:.*]] = arith.constant true
-// CHECK-DAG: %[[VAL_16:.*]] = memref.alloc() : memref<2xindex>
-// CHECK-DAG: %[[VAL_17:.*]] = memref.alloc() : memref<3xindex>
-// CHECK-DAG: %[[VAL_18:.*]] = memref.alloc() : memref<16xindex>
-// CHECK-DAG: %[[VAL_19:.*]] = memref.cast %[[VAL_18]] : memref<16xindex> to memref<?xindex>
-// CHECK-DAG: %[[VAL_20:.*]] = memref.alloc() : memref<16xindex>
-// CHECK-DAG: %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<16xindex> to memref<?xindex>
-// CHECK-DAG: %[[VAL_22:.*]] = memref.alloc() : memref<16xf64>
-// CHECK-DAG: %[[VAL_23:.*]] = memref.cast %[[VAL_22]] : memref<16xf64> to memref<?xf64>
+// CHECK-SAME: %[[VAL_0:.*0]]: memref<2xindex>,
+// CHECK-SAME: %[[VAL_1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[VAL_2:.*2]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_3:.*3]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[VAL_5:.*5]]: memref<2xindex>,
+// CHECK-SAME: %[[VAL_6:.*6]]: memref<3xindex>,
+// CHECK-SAME: %[[VAL_7:.*7]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_8:.*8]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_9:.*9]]: memref<?xf64>) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
+// CHECK: %[[VAL_10:.*]] = arith.constant 4 : index
+// CHECK: %[[VAL_11:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK: %[[VAL_12:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_13:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_14:.*]] = arith.constant false
+// CHECK: %[[VAL_15:.*]] = arith.constant true
+// CHECK: %[[VAL_16:.*]] = memref.alloc() : memref<2xindex>
+// CHECK: %[[VAL_17:.*]] = memref.alloc() : memref<3xindex>
+// CHECK: %[[VAL_18:.*]] = memref.alloc() : memref<16xindex>
+// CHECK: %[[VAL_19:.*]] = memref.cast %[[VAL_18]] : memref<16xindex> to memref<?xindex>
+// CHECK: %[[VAL_20:.*]] = memref.alloc() : memref<16xindex>
+// CHECK: %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<16xindex> to memref<?xindex>
+// CHECK: %[[VAL_22:.*]] = memref.alloc() : memref<16xf64>
+// CHECK: %[[VAL_23:.*]] = memref.cast %[[VAL_22]] : memref<16xf64> to memref<?xf64>
 // CHECK: linalg.fill ins(%[[VAL_12]] : index) outs(%[[VAL_17]] : memref<3xindex>)
 // CHECK: memref.store %[[VAL_10]], %[[VAL_16]]{{\[}}%[[VAL_12]]] : memref<2xindex>
 // CHECK: memref.store %[[VAL_10]], %[[VAL_16]]{{\[}}%[[VAL_13]]] : memref<2xindex>
@@ -49,84 +85,61 @@
 // CHECK: %[[VAL_29:.*]] = memref.cast %[[VAL_28]] : memref<4xindex> to memref<?xindex>
 // CHECK: linalg.fill ins(%[[VAL_11]] : f64) outs(%[[VAL_26]] : memref<4xf64>)
 // CHECK: linalg.fill ins(%[[VAL_14]] : i1) outs(%[[VAL_27]] : memref<4xi1>)
-// CHECK: %[[VAL_30:.*]]:2 = scf.for %[[VAL_31:.*]] = %[[VAL_12]] to %[[VAL_10]] step %[[VAL_13]] iter_args(%[[VAL_32:.*]] = %[[VAL_21]], %[[VAL_33:.*]] = %[[VAL_23]]) -> (memref<?xindex>, memref<?xf64>) {
-// CHECK: %[[VAL_34:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_31]]] : memref<?xindex>
-// CHECK: %[[VAL_35:.*]] = arith.addi %[[VAL_31]], %[[VAL_13]] : index
-// CHECK: %[[VAL_36:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_35]]] : memref<?xindex>
-// CHECK: %[[VAL_37:.*]] = scf.for %[[VAL_38:.*]] = %[[VAL_34]] to %[[VAL_36]] step %[[VAL_13]] iter_args(%[[VAL_39:.*]] = %[[VAL_12]]) -> (index) {
-// CHECK: %[[VAL_40:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_38]]] : memref<?xindex>
-// CHECK: %[[VAL_41:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_38]]] : memref<?xf64>
-// CHECK: %[[VAL_42:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_40]]] : memref<?xindex>
-// CHECK: %[[VAL_43:.*]] = arith.addi %[[VAL_40]], %[[VAL_13]] : index
-// CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_43]]] : memref<?xindex>
-// CHECK: %[[VAL_45:.*]] = scf.for %[[VAL_46:.*]] = %[[VAL_42]] to %[[VAL_44]] step %[[VAL_13]] iter_args(%[[VAL_47:.*]] = %[[VAL_39]]) -> (index) {
-// CHECK: %[[VAL_48:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_46]]] : memref<?xindex>
-// CHECK: %[[VAL_49:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_48]]] : memref<4xf64>
-// CHECK: %[[VAL_50:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_46]]] : memref<?xf64>
-// CHECK: %[[VAL_51:.*]] = arith.mulf %[[VAL_41]], %[[VAL_50]] : f64
-// CHECK: %[[VAL_52:.*]] = arith.addf %[[VAL_49]], %[[VAL_51]] : f64
-// CHECK: %[[VAL_53:.*]] = memref.load %[[VAL_27]]{{\[}}%[[VAL_48]]] : memref<4xi1>
-// CHECK: %[[VAL_54:.*]] = arith.cmpi eq, %[[VAL_53]], %[[VAL_14]] : i1
-// CHECK: %[[VAL_55:.*]] = scf.if %[[VAL_54]] -> (index) {
-// CHECK: memref.store %[[VAL_15]], %[[VAL_27]]{{\[}}%[[VAL_48]]] : memref<4xi1>
-// CHECK: memref.store %[[VAL_48]], %[[VAL_28]]{{\[}}%[[VAL_47]]] : memref<4xindex>
-// CHECK: %[[VAL_56:.*]] = arith.addi %[[VAL_47]], %[[VAL_13]] : index
-// CHECK: scf.yield %[[VAL_56]] : index
+// CHECK: %[[VAL_30:.*]]:5 = scf.for %[[VAL_31:.*]] = %[[VAL_12]] to %[[VAL_10]] step %[[VAL_13]] iter_args(%[[VAL_32:.*]] = %[[VAL_16]], %[[VAL_33:.*]] = %[[VAL_17]], %[[VAL_34:.*]] = %[[VAL_25]], %[[VAL_35:.*]] = %[[VAL_21]], %[[VAL_36:.*]] = %[[VAL_23]]) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
+// CHECK: %[[VAL_37:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_31]]] : memref<?xindex>
+// CHECK: %[[VAL_38:.*]] = arith.addi %[[VAL_31]], %[[VAL_13]] : index
+// CHECK: %[[VAL_39:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_38]]] : memref<?xindex>
+// CHECK: %[[VAL_40:.*]] = scf.for %[[VAL_41:.*]] = %[[VAL_37]] to %[[VAL_39]] step %[[VAL_13]] iter_args(%[[VAL_42:.*]] = %[[VAL_12]]) -> (index) {
+// CHECK: %[[VAL_43:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_41]]] : memref<?xindex>
+// CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_41]]] : memref<?xf64>
+// CHECK: %[[VAL_45:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_43]]] : memref<?xindex>
+// CHECK: %[[VAL_46:.*]] = arith.addi %[[VAL_43]], %[[VAL_13]] : index
+// CHECK: %[[VAL_47:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_46]]] : memref<?xindex>
+// CHECK: %[[VAL_48:.*]] = scf.for %[[VAL_49:.*]] = %[[VAL_45]] to %[[VAL_47]] step %[[VAL_13]] iter_args(%[[VAL_50:.*]] = %[[VAL_42]]) -> (index) {
+// CHECK: %[[VAL_51:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_49]]] : memref<?xindex>
+// CHECK: %[[VAL_52:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_51]]] : memref<4xf64>
+// CHECK: %[[VAL_53:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_49]]] : memref<?xf64>
+// CHECK: %[[VAL_54:.*]] = arith.mulf %[[VAL_44]], %[[VAL_53]] : f64
+// CHECK: %[[VAL_55:.*]] = arith.addf %[[VAL_52]], %[[VAL_54]] : f64
+// CHECK: %[[VAL_56:.*]] = memref.load %[[VAL_27]]{{\[}}%[[VAL_51]]] : memref<4xi1>
+// CHECK: %[[VAL_57:.*]] = arith.cmpi eq, %[[VAL_56]], %[[VAL_14]] : i1
+// CHECK: %[[VAL_58:.*]] = scf.if %[[VAL_57]] -> (index) {
+// CHECK: memref.store %[[VAL_15]], %[[VAL_27]]{{\[}}%[[VAL_51]]] : memref<4xi1>
+// CHECK: memref.store %[[VAL_51]], %[[VAL_28]]{{\[}}%[[VAL_50]]] : memref<4xindex>
+// CHECK: %[[VAL_59:.*]] = arith.addi %[[VAL_50]], %[[VAL_13]] : index
+// CHECK: scf.yield %[[VAL_59]] : index
 // CHECK: } else {
-// CHECK: scf.yield %[[VAL_47]] : index
+// CHECK: scf.yield %[[VAL_50]] : index
 // CHECK: }
-// CHECK: memref.store %[[VAL_52]], %[[VAL_26]]{{\[}}%[[VAL_48]]] : memref<4xf64>
-// CHECK: scf.yield %[[VAL_57:.*]] : index
-// CHECK: }
-// CHECK: scf.yield %[[VAL_58:.*]] : index
-// CHECK: }
-// CHECK: sparse_tensor.sort %[[VAL_59:.*]], %[[VAL_29]] : memref<?xindex>
-// CHECK: %[[VAL_60:.*]]:2 = scf.for %[[VAL_61:.*]] = %[[VAL_12]] to %[[VAL_59]] step %[[VAL_13]] iter_args(%[[VAL_62:.*]] = %[[VAL_32]], %[[VAL_63:.*]] = %[[VAL_33]]) -> (memref<?xindex>, memref<?xf64>) {
-// CHECK: %[[VAL_64:.*]] = memref.load %[[VAL_28]]{{\[}}%[[VAL_61]]] : memref<4xindex>
-// CHECK: %[[VAL_65:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_64]]] : memref<4xf64>
-// CHECK: %[[VAL_66:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_31]]] : memref<?xindex>
-// CHECK: %[[VAL_67:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_35]]] : memref<?xindex>
-// CHECK: %[[VAL_68:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_13]]] : memref<3xindex>
-// CHECK: %[[VAL_69:.*]] = arith.subi %[[VAL_67]], %[[VAL_13]] : index
-// CHECK: %[[VAL_70:.*]] = arith.cmpi ult, %[[VAL_66]], %[[VAL_67]] : index
-// CHECK: %[[VAL_71:.*]] = scf.if %[[VAL_70]] -> (i1) {
-// CHECK: %[[VAL_72:.*]] = memref.load %[[VAL_62]]{{\[}}%[[VAL_69]]] : memref<?xindex>
-// CHECK: %[[VAL_73:.*]] = arith.cmpi eq, %[[VAL_72]], %[[VAL_64]] : index
-// CHECK: scf.yield %[[VAL_73]] : i1
-// CHECK: } else {
-// CHECK: memref.store %[[VAL_68]], %[[VAL_25]]{{\[}}%[[VAL_31]]] : memref<?xindex>
-// CHECK: scf.yield %[[VAL_14]] : i1
-// CHECK: }
-// CHECK: %[[VAL_74:.*]] = scf.if %[[VAL_75:.*]] -> (memref<?xindex>) {
-// CHECK: scf.yield %[[VAL_62]] : memref<?xindex>
-// CHECK: } else {
-// CHECK: %[[VAL_76:.*]] = arith.addi %[[VAL_68]], %[[VAL_13]] : index
-// CHECK: memref.store %[[VAL_76]], %[[VAL_25]]{{\[}}%[[VAL_35]]] : memref<?xindex>
-// CHECK: %[[VAL_77:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_62]], %[[VAL_64]] {idx = 1 : index} : memref<3xindex>, memref<?xindex>, index
-// CHECK: scf.yield %[[VAL_77]] : memref<?xindex>
-// CHECK: }
-// CHECK: %[[VAL_78:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_63]], %[[VAL_65]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
-// CHECK: memref.store %[[VAL_11]], %[[VAL_26]]{{\[}}%[[VAL_64]]] : memref<4xf64>
-// CHECK: memref.store %[[VAL_14]], %[[VAL_27]]{{\[}}%[[VAL_64]]] : memref<4xi1>
-// CHECK: scf.yield %[[VAL_79:.*]], %[[VAL_78]] : memref<?xindex>, memref<?xf64>
+// CHECK: memref.store %[[VAL_55]], %[[VAL_26]]{{\[}}%[[VAL_51]]] : memref<4xf64>
+// CHECK: scf.yield %[[VAL_60:.*]] : index
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: sparse_tensor.sort %[[VAL_62:.*]], %[[VAL_29]] : memref<?xindex>
+// CHECK: %[[VAL_63:.*]]:5 = scf.for %[[VAL_64:.*]] = %[[VAL_12]] to %[[VAL_62]] step %[[VAL_13]] iter_args(%[[VAL_65:.*]] = %[[VAL_32]], %[[VAL_66:.*]] = %[[VAL_33]], %[[VAL_67:.*]] = %[[VAL_34]], %[[VAL_68:.*]] = %[[VAL_35]], %[[VAL_69:.*]] = %[[VAL_36]]) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
+// CHECK: %[[VAL_70:.*]] = memref.load %[[VAL_28]]{{\[}}%[[VAL_64]]] : memref<4xindex>
+// CHECK: %[[VAL_71:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_70]]] : memref<4xf64>
+// CHECK: %[[VAL_72:.*]]:5 = func.call @_insert_D_C_4_4_f64_0_0(%[[VAL_65]], %[[VAL_66]], %[[VAL_67]], %[[VAL_68]], %[[VAL_69]], %[[VAL_31]], %[[VAL_70]], %[[VAL_71]]) : (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>, index, index, f64) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>)
+// CHECK: memref.store %[[VAL_11]], %[[VAL_26]]{{\[}}%[[VAL_70]]] : memref<4xf64>
+// CHECK: memref.store %[[VAL_14]], %[[VAL_27]]{{\[}}%[[VAL_70]]] : memref<4xi1>
+// CHECK: scf.yield %[[VAL_72]]#0, %[[VAL_72]]#1, %[[VAL_72]]#2, %[[VAL_72]]#3, %[[VAL_72]]#4 : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
 // CHECK: }
-// CHECK: scf.yield %[[VAL_80:.*]]#0, %[[VAL_80]]#1 : memref<?xindex>, memref<?xf64>
-// CHECK: }
+// CHECK: scf.yield %[[VAL_73:.*]]#0, %[[VAL_73]]#1, %[[VAL_73]]#2, %[[VAL_73]]#3, %[[VAL_73]]#4 : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+// CHECK: } {"Emitted from" = "linalg.generic"}
 // CHECK: memref.dealloc %[[VAL_26]] : memref<4xf64>
 // CHECK: memref.dealloc %[[VAL_27]] : memref<4xi1>
 // CHECK: memref.dealloc %[[VAL_28]] : memref<4xindex>
-// CHECK: %[[VAL_81:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_12]]] : memref<3xindex>
-// CHECK: %[[VAL_82:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_12]]] : memref<?xindex>
-// CHECK: %[[VAL_83:.*]] = scf.for %[[VAL_84:.*]] = %[[VAL_13]] to %[[VAL_81]] step %[[VAL_13]] iter_args(%[[VAL_85:.*]] = %[[VAL_82]]) -> (index) {
-// CHECK: %[[VAL_86:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_84]]] : memref<?xindex>
-// CHECK: %[[VAL_87:.*]] = arith.cmpi eq, %[[VAL_86]], %[[VAL_12]] : index
-// CHECK: %[[VAL_88:.*]] = arith.select %[[VAL_87]], %[[VAL_85]], %[[VAL_86]] : index
-// CHECK: scf.if %[[VAL_87]] {
-// CHECK: memref.store %[[VAL_85]], %[[VAL_25]]{{\[}}%[[VAL_84]]] : memref<?xindex>
+// CHECK: %[[VAL_74:.*]] = memref.load %[[VAL_75:.*]]#1{{\[}}%[[VAL_12]]] : memref<3xindex>
+// CHECK: %[[VAL_76:.*]] = memref.load %[[VAL_75]]#2{{\[}}%[[VAL_12]]] : memref<?xindex>
+// CHECK: %[[VAL_77:.*]] = scf.for %[[VAL_78:.*]] = %[[VAL_13]] to %[[VAL_74]] step %[[VAL_13]] iter_args(%[[VAL_79:.*]] = %[[VAL_76]]) -> (index) {
+// CHECK: %[[VAL_80:.*]] = memref.load %[[VAL_75]]#2{{\[}}%[[VAL_78]]] : memref<?xindex>
+// CHECK: %[[VAL_81:.*]] = arith.cmpi eq, %[[VAL_80]], %[[VAL_12]] : index
+// CHECK: %[[VAL_82:.*]] = arith.select %[[VAL_81]], %[[VAL_79]], %[[VAL_80]] : index
+// CHECK: scf.if %[[VAL_81]] {
+// CHECK: memref.store %[[VAL_79]], %[[VAL_75]]#2{{\[}}%[[VAL_78]]] : memref<?xindex>
 // CHECK: }
-// CHECK: scf.yield %[[VAL_88]] : index
+// CHECK: scf.yield %[[VAL_82]] : index
 // CHECK: }
-// CHECK: return %[[VAL_16]], %[[VAL_17]], %[[VAL_25]], %[[VAL_89:.*]]#0, %[[VAL_89]]#1 : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+// CHECK: return %[[VAL_75]]#0, %[[VAL_75]]#1, %[[VAL_75]]#2, %[[VAL_75]]#3, %[[VAL_75]]#4 : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
 // CHECK: }
 func.func @matmul(%A: tensor<4x8xf64, #CSR>,
                   %B: tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR> {
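
Illustrative note (a minimal sketch, not part of the patch): under the mangling scheme documented in genInsertionCallHelper, an insertion into tensor<8x8xf64, #CSR> with indexBitWidth = 64 and pointerBitWidth = 32 resolves to the private function @_insert_D_C_8_8_f64_64_32 checked above ("D_" and "C_" for the dense and compressed dimensions, "8_8_" for the static shape, then the element type and the two bit widths). A call site emitted by the rewriter then looks roughly like the following, where the operand names are placeholders and the field types follow the CHECK lines above; the five returned memrefs replace the caller's current field values, which is why the surrounding loops now carry all five fields as iter_args:

    %r:5 = func.call @_insert_D_C_8_8_f64_64_32(%dimSizes, %memSizes, %pointers, %indices, %values, %i, %j, %v)
        : (memref<2xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>, index, index, f64)
       -> (memref<2xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>)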