diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_mttkrp.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_mttkrp.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_mttkrp.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_mttkrp.mlir @@ -75,49 +75,53 @@ // Main driver that reads matrix from file and calls the sparse kernel. // func @entry() { - %i0 = arith.constant 0. : f64 + %f0 = arith.constant 0.0 : f64 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index - %c3 = arith.constant 3 : index - %c4 = arith.constant 4 : index - %c5 = arith.constant 5 : index - %c256 = arith.constant 256 : index - // Read the sparse B input from a file. + // Read the sparse input tensor B from a file. %fileName = call @getTensorFilename(%c0) : (index) -> (!Filename) %b = sparse_tensor.new %fileName : !Filename to tensor - // Initialize dense C and D inputs and dense output A. - %cdata = memref.alloc(%c3, %c5) : memref - scf.for %i = %c0 to %c3 step %c1 { - scf.for %j = %c0 to %c5 step %c1 { - %k0 = arith.muli %i, %c5 : index + // Get sizes from B, pick a fixed size for dim-2 of A. + %isz = tensor.dim %b, %c0 : tensor + %jsz = arith.constant 5 : index + %ksz = tensor.dim %b, %c1 : tensor + %lsz = tensor.dim %b, %c2 : tensor + + // Initialize dense input matrix C. + %cdata = memref.alloc(%ksz, %jsz) : memref + scf.for %k = %c0 to %ksz step %c1 { + scf.for %j = %c0 to %jsz step %c1 { + %k0 = arith.muli %k, %jsz : index %k1 = arith.addi %k0, %j : index %k2 = arith.index_cast %k1 : index to i32 - %k = arith.sitofp %k2 : i32 to f64 - memref.store %k, %cdata[%i, %j] : memref + %kf = arith.sitofp %k2 : i32 to f64 + memref.store %kf, %cdata[%k, %j] : memref } } %c = bufferization.to_tensor %cdata : memref - %ddata = memref.alloc(%c4, %c5) : memref - scf.for %i = %c0 to %c4 step %c1 { - scf.for %j = %c0 to %c5 step %c1 { - %k0 = arith.muli %i, %c5 : index + // Initialize dense input matrix D. + %ddata = memref.alloc(%lsz, %jsz) : memref + scf.for %l = %c0 to %lsz step %c1 { + scf.for %j = %c0 to %jsz step %c1 { + %k0 = arith.muli %l, %jsz : index %k1 = arith.addi %k0, %j : index %k2 = arith.index_cast %k1 : index to i32 - %k = arith.sitofp %k2 : i32 to f64 - memref.store %k, %ddata[%i, %j] : memref + %kf = arith.sitofp %k2 : i32 to f64 + memref.store %kf, %ddata[%l, %j] : memref } } %d = bufferization.to_tensor %ddata : memref - %adata = memref.alloc(%c2, %c5) : memref - scf.for %i = %c0 to %c2 step %c1 { - scf.for %j = %c0 to %c5 step %c1 { - memref.store %i0, %adata[%i, %j] : memref + // Initialize dense output matrix A. + %adata = memref.alloc(%isz, %jsz) : memref + scf.for %i = %c0 to %isz step %c1 { + scf.for %j = %c0 to %jsz step %c1 { + memref.store %f0, %adata[%i, %j] : memref } } %a = bufferization.to_tensor %adata : memref @@ -133,7 +137,7 @@ // CHECK: ( 10000, 14225, 19180, 24865, 31280 ) ) // %m = bufferization.to_memref %0 : memref - %v = vector.transfer_read %m[%c0, %c0], %i0 + %v = vector.transfer_read %m[%c0, %c0], %f0 : memref, vector<2x5xf64> vector.print %v : vector<2x5xf64>